Hello everyone, I would like to build a simple model for my research, shown below:
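```python
import torch
import pyro
import pyro.distributions as dist

def model(sc_ref, sum, X=None):
    # `device` is a global defined elsewhere in my script
    n_spot = len(X)
    n_type = len(sc_ref)
    with pyro.plate("spot", n_spot):
        # cell-type proportions for each spot, flat Dirichlet prior
        p = pyro.sample("Proportion", dist.Dirichlet(torch.full((n_type,), 1., device=device, dtype=torch.float64)))
        mu = torch.matmul(p, sc_ref)
        # total_count has to be a single Python int here
        pyro.sample("Spatial RNA", dist.Multinomial(total_count=int(sum), probs=mu), obs=X)
```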
As you can see, I have a multinomial likelihood batched over the plate “spot”. I would like each spot in the plate to have its own total_count. How can I achieve this? It seems that total_count must be a single integer rather than a 1-dimensional tensor.
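A quick check with plain `torch.distributions` (which Pyro's `dist.Multinomial` wraps) suggests why a per-spot tensor may not be needed for scoring observed data: `Multinomial.log_prob` computes the combinatorial term from the observed counts themselves, and its support check only requires each row to sum to at most `total_count`. The counts and probabilities below are made up purely for illustration.

```python
import torch
import torch.distributions as tdist

# Two spots with different library sizes (made-up counts for illustration)
x = torch.tensor([[3., 1., 0.], [10., 5., 2.]])
probs = torch.tensor([0.5, 0.3, 0.2])

# Per-row log-likelihood, each row scored with its own total_count
exact = torch.stack([
    tdist.Multinomial(total_count=int(row.sum()), probs=probs).log_prob(row)
    for row in x
])

# Vectorized: one total_count that upper-bounds every row's sum (here 17)
shared = tdist.Multinomial(total_count=int(x.sum(-1).max()), probs=probs).log_prob(x)

print(torch.allclose(exact, shared))  # True
```

The two agree because `total_count` is only used for sampling and for the `counts.sum(-1) <= total_count` validation, not inside `log_prob`.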
Here is my current solution. It is not only ugly but also very slow, since it loops over the spots one at a time. I would appreciate it if anyone has a better idea.
```python
def model(sc_ref, sum, X=None):
    n_spot = len(X)
    n_type = len(sc_ref)
    # sequential plate: a separate pair of sample sites per spot, so each
    # spot can use its own total_count (but nothing is vectorized)
    for i in pyro.plate("spot", n_spot):
        p = pyro.sample("Proportion_{}".format(i), dist.Dirichlet(torch.full((n_type,), 1., device=device, dtype=torch.float64)))
        mu = torch.matmul(p, sc_ref)
        pyro.sample("Spatial RNA_{}".format(i), dist.Multinomial(total_count=int(sum[i]), probs=mu), obs=X[i])
```