Multinomial distribution with different total_count values in a plate

Hello everyone, I would like to build a simple model for my research, which is shown below:

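import torch
import pyro
import pyro.distributions as dist

device = torch.device("cpu")  # assumed; set this to your own device

def model(sc_ref, sum, X=None):
    n_spot = len(X)
    n_type = len(sc_ref)
    with pyro.plate("spot", n_spot):
        p = pyro.sample("Proportion", dist.Dirichlet(torch.full((n_type,), 1., device=device, dtype=torch.float64)))
        mu = torch.matmul(p, sc_ref)
        # int(sum) only works if sum is a single number, so all spots share one total_count
        pyro.sample("Spatial RNA", dist.Multinomial(total_count=int(sum), probs=mu), obs=X)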

As you can see, I have a multinomial distribution across the batch “spot”. I am wondering how I can use a different total_count in the multinomial distribution for each sample in the batch (plate). It seems that total_count must be an integer rather than a 1-dimensional tensor.

Here is my current solution. It is not only ugly but also very time-consuming. I would appreciate it if anyone has a better idea.

def model(sc_ref, sum, X=None):
    n_spot = len(X)
    n_type = len(sc_ref)
    # pyro.plate used as an iterator is a sequential plate: one Python iteration per spot,
    # so every sample site needs a unique name (hence the "_{i}" suffix)
    for i in pyro.plate("spot", n_spot):
        p = pyro.sample("Proportion_{}".format(i), dist.Dirichlet(torch.full((n_type,), 1., device=device, dtype=torch.float64)))
        mu = torch.matmul(p, sc_ref)
        pyro.sample("Spatial RNA_{}".format(i), dist.Multinomial(total_count=int(sum[i]), probs=mu), obs=X[i])

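When computing the log probability, the Multinomial distribution ignores the total_count argument: it recomputes the total from the observed values themselves. Since X is observed in your model, you never sample from the Multinomial, so total_count does not affect the likelihood and you can simply set total_count=1. One caveat: with argument validation enabled (the default in recent PyTorch versions), the observed counts must not sum to more than total_count, so total_count=1 would be rejected. In that case either pass validate_args=False or use the largest per-spot total, e.g. total_count=int(sum.max()).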
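For reference, the quantity that PyTorch's Multinomial.log_prob evaluates for a count vector x = (x_1, ..., x_K) can be written as follows (the symbol π for the normalized probabilities is our notation, not from the thread). The total n is taken from the counts, which is why total_count never appears:

```latex
\log p(x \mid \pi) = \log\Gamma(n + 1) - \sum_{i=1}^{K} \log\Gamma(x_i + 1) + \sum_{i=1}^{K} x_i \log \pi_i,
\qquad n = \sum_{i=1}^{K} x_i
```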

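import torch
import torch.distributions as dist

# validate_args=False skips the support check (value.sum(-1) <= total_count),
# which would otherwise reject x for m1 because its counts sum to 10
m1 = dist.Multinomial(1, torch.tensor([1.0, 1.0, 1.0, 1.0]), validate_args=False)
m2 = dist.Multinomial(10, torch.tensor([1.0, 1.0, 1.0, 1.0]), validate_args=False)
x = torch.tensor([2., 5., 2., 1.])
expected = m1.log_prob(x)
actual = m2.log_prob(x)
assert expected == actual  # log_prob does not depend on total_count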
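To make the batched version concrete, here is a minimal sketch in plain PyTorch with made-up data (3 hypothetical spots, 4 genes, each spot with a different total count). It assumes the validation rule described above, i.e. that only value.sum(-1) <= total_count is checked, and compares one batched Multinomial using the largest total against the slow per-spot loop with exact totals:

```python
import torch
import torch.distributions as dist

# Hypothetical data: 3 spots x 4 genes, each spot with its own total count.
probs = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                      [0.25, 0.25, 0.25, 0.25],
                      [0.4, 0.3, 0.2, 0.1]], dtype=torch.float64)
x = torch.tensor([[1., 2., 3., 4.],   # total 10
                  [5., 0., 1., 1.],   # total 7
                  [2., 2., 0., 0.]],  # total 4
                 dtype=torch.float64)

totals = x.sum(-1)

# One shared total_count for the whole batch: the largest per-spot total.
# It passes argument validation (every row sums to at most this value), and
# log_prob recomputes n from each row, so each spot still gets its own total.
batched = dist.Multinomial(total_count=int(totals.max()), probs=probs)
batched_lp = batched.log_prob(x)  # shape (3,), one log-likelihood per spot

# Reference: one Multinomial per spot with its exact total (the slow loop).
looped_lp = torch.stack([
    dist.Multinomial(total_count=int(n), probs=pr).log_prob(row)
    for n, pr, row in zip(totals, probs, x)
])

print(torch.allclose(batched_lp, looped_lp))
```

In the question's model this would amount to passing total_count=int(sum.max()) to dist.Multinomial inside the vectorized plate, assuming sum is a tensor of per-spot totals and that pyro.distributions.Multinomial uses PyTorch's log_prob unchanged.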

Thank you for the answer. That makes a lot of sense.