Config enumerate over Multinomial distributions

Hi!

I’m new to Pyro, and I really enjoy this framework. I would like to do some modelling that involves a multinomial distribution. I’ve read the enumeration tutorial, and I decided to try it:

import torch
import pyro
import pyro.distributions as dist

N = 10  # number of categories (matches the 10-element output below)

def model():
    multi = pyro.sample('multi', dist.Multinomial(2, torch.ones(N) / N))

def guide():
    multi = pyro.sample('multi', dist.Multinomial(2, torch.ones(N) / N))
    print(multi)

elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=0)
elbo.loss(model, pyro.infer.config_enumerate(guide, "parallel"))

The printed value shows that no enumeration happened; the site was simply sampled once:

tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1.])

I believe it’s possible to solve the same problem using 2 Categorical distributions, but I was wondering whether you plan to support enumeration for Multinomial distributions (or, if you would like to have it, I could contribute).

Thanks!

Maybe another question:

Is it possible to provide my own marginalization, as in Stan?

Hi @Owy,

Pyro currently does not support enumeration over Multinomial distributions. This was intentional: the support of a Multinomial grows combinatorially (with total_count n and k categories there are C(n + k - 1, k - 1) possible count vectors), which makes enumeration infeasible for all but the smallest Multinomial distributions. In those small cases I often hand-unroll the Multinomial into Categorical draws.
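To put numbers on that growth, here is a small stars-and-bars calculation of the Multinomial support size; the helper name multinomial_support_size is just for illustration.

```python
from math import comb

def multinomial_support_size(total_count, num_categories):
    # Stars and bars: count non-negative integer vectors of length
    # num_categories whose entries sum to total_count.
    return comb(total_count + num_categories - 1, num_categories - 1)

print(multinomial_support_size(2, 10))   # 55
print(multinomial_support_size(10, 10))  # 92378
print(multinomial_support_size(20, 20))  # 68923264410
```

For comparison, unrolling into n Categorical draws enumerates k**n ordered outcomes (100 for two draws over 10 categories), which is only reasonable when n is tiny.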

If you really want to add support for Multinomial enumeration, you could add an .enumerate_support() method to the Multinomial class and set the has_enumerate_support class attribute to True, for example in a subclass:

from pyro.distributions import Multinomial

class EnumerableMultinomial(Multinomial):
    has_enumerate_support = True
    def enumerate_support(self, expand=True):
        # TODO assert total_count is homogeneous; otherwise support size varies.
        raise NotImplementedError("TODO")

I believe you should be able to use this in your proposed guide.

Hi!

Thanks for the answer! That makes sense; I suppose at some point you had to choose between API consistency and avoiding very expensive computations. In my case the number of trials will be very small, so I will implement the class as you suggested.