Change dtype in a sampling statement

Hi there.

I would like to change the data type when sampling from a distribution as part of a model. I have some code that looks like this:

with numpyro.plate('data', N):
    x = numpyro.sample('x', dist.Categorical(weights))

The variable x can only take a small set of values, and I have some very large datasets I would like to analyse, so I would like to change the datatype returned by this sampling statement to uint8 to save some memory. At the moment it’s int32. Is that possible?

Thanks :slight_smile:

I think you can subclass the Categorical distribution, e.g. as CategoricalUInt8, and change the dtype there.

Thanks @fehiepsi . I have redefined the distribution like this:

import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions.util import is_prng_key


class CategoricalUInt8(dist.CategoricalProbs):
    def __init__(self, probs, validate_args=None):
        super().__init__(probs=probs, validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        assert is_prng_key(key)
        samples = dist.util.categorical(key, self.probs, shape=sample_shape + self.batch_shape)
        return samples.astype(jnp.uint8)
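For reference, the saving this cast buys can be seen with plain JAX, outside any model (a minimal sketch, not the code above): `jax.random.categorical` returns int32 under the default config, and casting to uint8 cuts the per-sample memory by a factor of four.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.log(jnp.array([0.2, 0.5, 0.3]))

# One million categorical draws: int32 under the default (x64-disabled)
# config, then cast down to uint8.
x32 = jax.random.categorical(key, logits, shape=(1_000_000,))
x8 = x32.astype(jnp.uint8)

print(x32.dtype, x32.nbytes)  # 4 bytes per sample
print(x8.dtype, x8.nbytes)    # 1 byte per sample
```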

I put a print statement in the model to check the data type, and this is the output:

[screenshot: the printed dtype is uint8 at every step except the second-to-last, which is int32]

Do you have any idea what’s going on with that second-to-last entry? I’m fitting using SVI.

I guess it is not sampled from the model. Are you trying to collect samples or doing inference? If you are doing inference, you can add an if/else check in your model and set a breakpoint when int32 happens.
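One way to set up that check (a sketch; the helper name is mine): wrap the sampled value in a small dtype assertion so the offending trace is easy to find. You could call `breakpoint()` instead of raising.

```python
import jax.numpy as jnp

def assert_uint8(x, name="x"):
    # dtype is available even on tracers, so this check also fires
    # during jit/grad tracing.
    if x.dtype != jnp.uint8:
        # Swap the raise for breakpoint() to inspect the trace interactively.
        raise TypeError(f"{name} has dtype {x.dtype}, expected uint8")
    return x

# Inside the model, something like:
#   x = assert_uint8(numpyro.sample('x', CategoricalUInt8(weights)))
```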

I’m doing inference, so I’ve had a look with the debugger. I’m not completely sure what I’m looking for, but on the several occasions when the sampled variable has the correct type it’s an ArrayImpl, and then a DynamicJaxprTracer with dtype = UInt8DType. The time it flips over to int32 it is still a DynamicJaxprTracer, but all the other model variables become JVPTracers.

Can you let me know what information I’m looking for please?

Sorry, I don’t have further suggestions. This seems to relate to some JAX behavior. I would suggest at least providing a minimal reproducible example.


Good point. Here is an example of the problem (maybe not quite minimal, but hopefully informative). It happens in the final step of fitting with SVI; when I construct the posterior predictive distribution, the datatype is correct again. I know it’s a problem for Dirichlet process mixture models, but it might be a problem for other models too.

The code is a bit long to copy and paste here, so I put it in a Colab notebook.


It seems that int32 is the dtype of the enumerated value, and this is part of the funsor library. Things might be fine here because the enumerated value is cheap to create and does not use much memory (its size is the support size, regardless of the number of batch dimensions).

Thank you! It’s strange because, looking at memory use, there’s a very large allocation that happens once before training starts, and then memory use is pretty minimal for the whole training process. I have a dataset with a few million rows, and the allocation that breaks my training run is about 130 GB. Once it’s running, it only uses about 20 GB.

Maybe this tool is helpful: Device Memory Profiling — JAX documentation
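For what it’s worth, a minimal sketch of using that profiler: dump a snapshot of live device memory to a file, then inspect it with pprof.

```python
import jax.numpy as jnp
import jax.profiler

# Allocate something on the device so there is live memory to report.
x = jnp.ones((1000, 1000))
x.block_until_ready()

# Write a pprof-format snapshot of current device memory;
# view it with e.g. `pprof --web memory.prof`.
jax.profiler.save_device_memory_profile("memory.prof")
```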
