Access to discrete variables (for Poisson distribution)

I have the following model:

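import numpyro
import numpyro.distributions as dist

def model(counts_data, p):
    sample_shape = counts_data.shape

    # Hyperparameter: shared Poisson rate for every entry of number_count
    counts_average = numpyro.sample('counts_average', dist.Uniform(0, 5))
    # Latent counts: one independent Poisson draw per entry of the data matrix
    number_count = numpyro.sample('number_count', dist.Poisson(rate=counts_average), sample_shape=sample_shape)
    # Observed counts: Binomial with number_count trials and success probability p
    numpyro.sample('data', dist.Binomial(total_count=number_count, probs=p), obs=counts_data)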

A hyperparameter, counts_average, sets the Poisson rate for every entry of the matrix number_count; the entries are independent. For each entry of this matrix, the corresponding counts_data value is drawn from a Binomial distribution with number_count trials and a given success probability p.

I have two problems.

Problem 1:
The Poisson distribution does not implement enumerate_support, so number_count cannot be enumerated. Is there a way around this? (I am using the NUTS sampler.)

Problem 2:
I can analytically marginalize over number_count, but I would still like to obtain posterior samples of the number_count variable. Is this possible?

Thank you very much!
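
For reference, here is a sketch of the analytic marginalization mentioned in Problem 2. The derivation is not shown in the post, so this is an assumption about what is meant: by Poisson thinning, counts_data is Poisson with rate counts_average * p, and given counts_data and counts_average the unobserved remainder number_count - counts_data is Poisson with rate counts_average * (1 - p). The standard-library check below compares the brute-force sum with the closed form and draws number_count from that conditional. All function names and the values of rate, p and y are illustrative, not taken from the post.

```python
import math
import random


def poisson_pmf(k, rate):
    """Poisson probability mass at integer k >= 0 (rate must be > 0)."""
    return math.exp(k * math.log(rate) - rate - math.lgamma(k + 1))


def binomial_pmf(k, n, p):
    """Binomial probability of k successes in n trials (0 if k is out of range)."""
    if k < 0 or k > n:
        return 0.0
    return math.comb(n, k) * p**k * (1.0 - p) ** (n - k)


def marginal_by_summation(y, rate, p, n_max=200):
    """Brute-force marginal of y: sum the joint over number_count = y .. n_max."""
    return sum(poisson_pmf(n, rate) * binomial_pmf(y, n, p) for n in range(y, n_max + 1))


def sample_poisson(rate, rng):
    """Knuth's multiplication method; adequate for the small rates used here."""
    threshold = math.exp(-rate)
    k, prod = 0, rng.random()
    while prod > threshold:
        k += 1
        prod *= rng.random()
    return k


def sample_number_count(y, rate, p, rng):
    """Draw number_count given y and rate: observed count plus unobserved remainder."""
    return y + sample_poisson(rate * (1.0 - p), rng)


# Illustrative values (assumptions, not from the post).
rate, p, y = 3.2, 0.4, 2
brute = marginal_by_summation(y, rate, p)
closed = poisson_pmf(y, rate * p)  # thinning: y ~ Poisson(rate * p)
print(brute, closed)

rng = random.Random(0)
draws = [sample_number_count(y, rate, p, rng) for _ in range(5)]
print(draws)
```

With the marginalized model, NUTS only needs to sample counts_average, and number_count can then be drawn for each posterior sample of counts_average with sample_number_count. Note that this applies to the untruncated Poisson, not to a truncated one.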

To enumerate, I think you need a right-truncated Poisson, so that the support is finite. You can then use Predictive with infer_discrete=True to get posterior samples of number_count.
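
To make the finite-support point concrete, here is a minimal standard-library sketch of what a right-truncated Poisson looks like: the Poisson mass on {0, ..., high} is renormalized, so every support value can be listed explicitly. The function names are hypothetical. The cutoff 40 and the upper rate 5 match the values that appear in the model code in this thread.

```python
import math


def truncated_poisson_log_probs(rate, high):
    """Log-probabilities of Poisson(rate) restricted to {0, ..., high}, renormalized."""
    # Unnormalized log-pmf; the -rate term cancels in the normalization.
    unnorm = [k * math.log(rate) - math.lgamma(k + 1) for k in range(high + 1)]
    m = max(unnorm)
    log_norm = m + math.log(sum(math.exp(u - m) for u in unnorm))
    return [u - log_norm for u in unnorm]


def poisson_tail_mass(rate, high, terms=200):
    """Mass an untruncated Poisson(rate) puts above `high`, summed term by term."""
    return sum(
        math.exp(k * math.log(rate) - rate - math.lgamma(k + 1))
        for k in range(high + 1, high + 1 + terms)
    )


high = 40
support = list(range(high + 1))           # the finite set that can be enumerated
log_probs = truncated_poisson_log_probs(5.0, high)
total = sum(math.exp(lp) for lp in log_probs)
tail = poisson_tail_mass(5.0, high)       # what the truncation throws away
print(len(support), total, tail)
```

For rates up to 5, the discarded tail mass above 40 is many orders of magnitude below floating-point precision, so the truncation should not noticeably change the model.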

Thanks a lot. I have now implemented the right-truncated Poisson distribution. Which of the two options in the model below do you recommend?

def model(counts_data, p):

    sample_shape = counts_data.shape
    high = 40

    counts_average = numpyro.sample('counts_average', dist.Uniform(0, 5))

    # first option: number_count is registered as a NumPyro sample site
    number_count = numpyro.sample('number_count', RightTruncatedPoisson(rate = counts_average, high = high), sample_shape = sample_shape)
    ###

    # second option: draw directly from the distribution object
    # (not a sample site; rng_key would have to be supplied by the caller)
    right_truncated_poisson = RightTruncatedPoisson(rate = counts_average, high = high)
    number_count = right_truncated_poisson.sample(rng_key, sample_shape = sample_shape)
    ###

    numpyro.sample('data', dist.Binomial(total_count = number_count, probs = p), obs = counts_data)
        

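Our inference algorithms work with the first option, because number_count has to be a numpyro.sample site. Use numpyro.plate rather than sample_shape to declare the independent entries.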
