I am very new to SVI and Pyro in general. I am trying to implement a simple model that can be seen as an extension of the Gaussian mixture example. The setup of the problem is as follows:
We have $N$ individuals, each belonging to a single class $z_i \in [L]$ for $i = 1, 2, \dots, N$. There are $J$ questions with parameter matrix $\beta \in \mathbb{R}^{J \times M}$, and each class has a parameter matrix $\Delta_l \in \mathbb{R}^{J \times M}$ for $l \in [L]$, so $\Delta \in \mathbb{R}^{L \times J \times M}$. We observe the binary response matrix $Y \in \{0, 1\}^{N \times J}$, where $P(Y_{ij} = 1 \mid z_i) = \sigma(\beta_j^{\top} \Delta_{z_i, j})$.
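To make the shapes in this setup concrete, here is a minimal PyTorch sketch that evaluates $\sigma(\beta_j^{\top} \Delta_{l, j})$ for every class $l$ and question $j$ at once. The values of `beta_mat` and `delta_mat` are random stand-ins (hypothetical); only the sizes come from the post.

```python
import torch

torch.manual_seed(0)
N, J, L, M = 30, 10, 4, 4  # sizes used in the post

# Random stand-ins (hypothetical values) for the known parameters
beta_mat = torch.randn(J, M)      # beta,  shape (J, M)
delta_mat = torch.randn(L, J, M)  # Delta, shape (L, J, M)

# logits[l, j] = beta_j^T Delta_{l, j}: elementwise product, then sum over M
logits = (delta_mat * beta_mat).sum(-1)  # shape (L, J)
probs = torch.sigmoid(logits)            # probs[l, j] = P(Y_ij = 1 | z_i = l)

# Given one class label per individual, indexing picks one row of probs each
z = torch.randint(L, (N,))
probs_per_individual = probs[z]          # shape (N, J), the same layout as Y
```

Because the inner product runs over $M$, the reduction is over the last axis; indexing the resulting $L \times J$ table by a vector of class labels gives an $N \times J$ array aligned with $Y$.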
We assume knowledge of $\Delta$ and $\beta$, observe $Y$, and would like to estimate $z$. Here is my attempt:
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate

@config_enumerate
def model(Y):
    d0 = torch.ones(L)
    pi = pyro.sample("pi", dist.Dirichlet(d0))
    with pyro.plate("assignments", N):
        z = pyro.sample("z", dist.Categorical(torch.ones(L)), infer={"enumerate": "parallel"})
        pyro.sample("Y", dist.Bernoulli(torch.sigmoid(torch.sum((delta_mat[z] * beta_mat), axis = 2))),
                    obs = Y)
```
In my tests, I have N = 30, J = 10, L = 4, M = 4. I get the following error:
```
ValueError: Shape mismatch inside plate('assignments') at site Y dim -1, 30 vs 4
   Trace Shapes:
    Param Sites:
   Sample Sites:
          pi dist     | 4
            value     | 4
 assignments dist     |
            value  30 |
           z dist  30 |
            value 4 1 |
```
I am a bit confused about the plate dimensions in this case. What is a possible fix for this issue? Thanks in advance.
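The shapes in the trace can be reproduced with plain tensor indexing. In this sketch (random stand-in parameters, hypothetical), `z_enum` mimics the enumerated value of `z` that Pyro reports as `4 1` under a single plate:

```python
import torch

N, J, L, M = 30, 10, 4, 4
beta_mat = torch.randn(J, M)
delta_mat = torch.randn(L, J, M)

# With a single plate, parallel enumeration places z's L values at dim -2,
# which is the "value 4 1" row in the trace
z_enum = torch.arange(L).reshape(L, 1)

gathered = delta_mat[z_enum] * beta_mat  # shape (L, 1, J, M) = (4, 1, 10, 4)
summed = torch.sum(gathered, dim=2)      # dim 2 is J, so the result is (L, 1, M)

# Pyro checks dim -1 against the plate size N = 30 and finds M = 4
print(summed.shape)  # torch.Size([4, 1, 4])

# Summing over M instead gives (L, 1, J): dim -1 is now J = 10, still not N
fixed = gathered.sum(-1)
```

The last line shows that fixing the reduction axis alone still leaves $J = 10$ at dim -1, so a single plate over individuals cannot line up with the $N \times J$ response matrix.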
You need an additional plate for J. Since your Y has shape N×J and the plate dims need to match your data, I would use two plates: `pyro.plate("individuals", N, dim=-2)` and `pyro.plate("questions", J, dim=-1)`. Then it would look something like this:
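```python
with pyro.plate("individuals", N, dim=-2):
    # use the sampled mixture weights pi; when enumerated, z has shape (L, 1, 1)
    z = pyro.sample("z", dist.Categorical(pi), infer={"enumerate": "parallel"})
    # squeeze z's trailing dim -> (L, 1, J) after the inner product over M
    logits = (delta_mat[z.squeeze(-1)] * beta_mat).sum(-1)
    with pyro.plate("questions", J, dim=-1):
        pyro.sample("Y", dist.Bernoulli(logits=logits), obs=Y)
```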
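You also need to make sure that the Bernoulli probabilities (or logits) have a shape that lines up with both plates. With two plates you should use `TraceEnum_ELBO(max_plate_nesting=2)`, so the enumerated z sits at dim -3 with shape L×1×1, and the probabilities need shape L×N×J, or L×1×J, which broadcasts over the individuals plate: L is for the enumerated z, N for the individuals plate, and J for the questions plate. One way to get there is `(delta_mat[z.squeeze(-1)] * beta_mat).sum(-1)`: squeezing z's trailing dim turns L×1×1 into L×1, indexing then gives L×1×J×M, and summing over M leaves L×1×J. Note that $\beta_j^{\top} \Delta_{z_i, j}$ is an inner product over M, so the sum must run over the last axis; in the original code `axis = 2` sums over J instead, which is why the error finds 4 (that is, M) at dim -1 where the plate expects 30.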