# Multinomial distribution with different total_count values in a plate

Hello everyone, I would like to build a simple model for my research, which is shown below:

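``````
import torch
import pyro
import pyro.distributions as dist

# `device` is assumed to be defined elsewhere, e.g. torch.device("cpu")
def model(sc_ref, sum, X=None):
    n_spot = len(X)
    n_type = len(sc_ref)
    with pyro.plate("spot", n_spot):
        # One vector of cell-type proportions per spot
        p = pyro.sample("Proportion", dist.Dirichlet(torch.full((n_type,), 1., device=device, dtype=torch.float64)))
        mu = torch.matmul(p, sc_ref)
        pyro.sample("Spatial RNA", dist.Multinomial(total_count=int(sum), probs=mu), obs=X)
``````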

As you can see, the multinomial distribution is batched over the "spot" plate. If I want a different `total_count` for each sample in the batch (plate), how can I achieve that? It seems that `total_count` must be a single integer rather than a 1-dimensional tensor.

Here is my current solution. It is not only ugly but also very slow. I would appreciate it if anyone has a better idea.

``````
def model(sc_ref, sum, X=None):
    n_spot = len(X)
    n_type = len(sc_ref)
    for i in pyro.plate("spot", n_spot):
        p = pyro.sample("Proportion_{}".format(i), dist.Dirichlet(torch.full((n_type,), 1., device=device, dtype=torch.float64)))
        mu = torch.matmul(p, sc_ref)
        pyro.sample("Spatial RNA_{}".format(i), dist.Multinomial(total_count=int(sum[i]), probs=mu), obs=X[i])
``````

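When computing the log probability, the `Multinomial` distribution ignores its `total_count` argument and instead derives the total from the observed values. Since your model only conditions on observations and never samples from the `Multinomial`, you can safely set `total_count=1` and keep the vectorized plate. One caveat: depending on your PyTorch version, argument validation may reject observations whose sum exceeds `total_count`; if that happens, pass `validate_args=False` or use an upper bound such as the largest row sum.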

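``````
import torch
import torch.distributions as dist

probs = torch.tensor([1.0, 1.0, 1.0, 1.0])  # normalized internally
# validate_args=False: some PyTorch versions otherwise reject values whose sum exceeds total_count
m1 = dist.Multinomial(1, probs, validate_args=False)
m2 = dist.Multinomial(10, probs, validate_args=False)
x = torch.tensor([2.0, 5.0, 2.0, 1.0])
expected = m1.log_prob(x)
actual = m2.log_prob(x)
assert expected == actual
``````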
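To connect this back to the batched case in the question, here is a minimal sketch using plain `torch.distributions` only. The data (`probs`, `X`) are made up for illustration, and `validate_args=False` is an assumption about how to sidestep the support check rather than something the original model requires. It compares one batched `Multinomial` with a placeholder `total_count` against a per-row loop that uses each row's true total.

```python
import torch
import torch.distributions as dist

torch.manual_seed(0)

# Hypothetical data: 3 spots, 4 genes, a different total count per spot.
probs = torch.softmax(torch.randn(3, 4), dim=-1)  # one probability row per spot
X = torch.tensor([[2.0, 5.0, 2.0, 1.0],   # sums to 10
                  [0.0, 1.0, 3.0, 0.0],   # sums to 4
                  [7.0, 7.0, 7.0, 4.0]])  # sums to 25

# Batched: a single Multinomial with a placeholder total_count.
# validate_args=False skips the support check, which may otherwise
# reject rows whose sum exceeds total_count.
batched = dist.Multinomial(total_count=1, probs=probs, validate_args=False)
batched_log_prob = batched.log_prob(X)  # one log probability per spot

# Reference: one Multinomial per spot, each with its true total count.
looped_log_prob = torch.stack([
    dist.Multinomial(total_count=int(x.sum()), probs=p).log_prob(x)
    for x, p in zip(X, probs)
])

assert torch.allclose(batched_log_prob, looped_log_prob)
```

In the model from the question, the same idea would presumably amount to keeping the vectorized `with pyro.plate("spot", n_spot):` version and passing a placeholder `total_count` (plus `validate_args=False` if validation complains) to `dist.Multinomial`. This only holds while the site is observed; sampling from it would draw `total_count` items per spot.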

Thank you for the answer. That makes a lot of sense.