Shape mismatch

Hi,

I am very new to SVI and Pyro in general. I am trying to implement a simple model that can be seen as an extension of the Gaussian mixture example. The setup of the problem is as follows:
We have $N$ individuals, each of whom belongs to a single class $z_i \in [L]$ for $i = 1, 2, \dots, N$. There are $J$ questions with parameter matrix $\beta \in \mathbb{R}^{J \times M}$, and each class has a parameter matrix $\Delta_l \in \mathbb{R}^{J \times M}$ for $l \in [L]$, so $\Delta \in \mathbb{R}^{L \times J \times M}$. We have the binary response matrix $Y \in \{0, 1\}^{N \times J}$, where
$$P(Y_{ij} = 1 \mid z_i) = \sigma(\beta_j^{\top} \Delta_{z_i, j}),$$
with $\sigma$ the logistic sigmoid and $\beta_j, \Delta_{z_i, j} \in \mathbb{R}^M$ the $j$-th rows of $\beta$ and $\Delta_{z_i}$.
We assume knowledge of $\Delta$ and $\beta$, observe $Y$, and would like to estimate $z$. Here is my attempt:
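To make the likelihood concrete: for a given assignment z, all N x J probabilities can be computed at once with broadcasting. This is a minimal torch-only sketch in which beta_mat, delta_mat and z are random stand-ins (assumptions, not part of the original setup), using the sizes mentioned later in the thread:

```python
import torch

# Sizes from the thread; beta_mat, delta_mat and z are random stand-ins.
N, J, L, M = 30, 10, 4, 4
torch.manual_seed(0)
beta_mat = torch.randn(J, M)      # row j is beta_j
delta_mat = torch.randn(L, J, M)  # delta_mat[l] is Delta_l
z = torch.randint(0, L, (N,))     # hypothetical class label for each individual

# delta_mat[z] has shape (N, J, M); beta_mat (J, M) broadcasts against it, and
# summing over the last axis (M) gives logits[i, j] = beta_j^T Delta_{z_i, j}.
logits = (delta_mat[z] * beta_mat).sum(dim=-1)  # shape (N, J)
probs = torch.sigmoid(logits)                   # P(Y_ij = 1 | z_i)

# The same contraction written with einsum.
logits_einsum = torch.einsum("njm,jm->nj", delta_mat[z], beta_mat)
```

Reducing over the last axis (dim=-1) rather than a hard-coded axis keeps the sum on the M axis even when extra batch dimensions appear on the left, which is what happens to z under enumeration.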

    @config_enumerate
    def model(Y):
        d0 = torch.ones(L)
        pi = pyro.sample("pi", dist.Dirichlet(d0))
        with pyro.plate("assignments", N):
            z = pyro.sample("z", dist.Categorical(torch.ones(L)), infer={"enumerate": "parallel"})
            pyro.sample("Y", dist.Bernoulli(torch.sigmoid(torch.sum((delta_mat[z] * beta_mat), axis = 2))),
                            obs = Y)

In my tests, I have N = 30, J = 10, L = 4, M = 4. I get the following error:

    ValueError: Shape mismatch inside plate('assignments') at site Y dim -1, 30 vs 4
       Trace Shapes:         
        Param Sites:         
       Sample Sites:         
             pi dist      | 4
               value      | 4
    assignments dist      |  
               value   30 |  
              z dist   30 |  
               value 4  1 |  

I am a bit confused about the plate dimensions in this case. What is a possible fix for this issue? Thanks in advance.
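The numbers in the error can be reproduced by hand. According to the trace, the enumerated z has shape (4, 1): Pyro stacks all L candidate classes along a new dimension to the left of the plate. A torch-only sketch, assuming random beta_mat and delta_mat of the stated sizes, of what the Y line then computes:

```python
import torch

N, J, L, M = 30, 10, 4, 4
beta_mat = torch.randn(J, M)
delta_mat = torch.randn(L, J, M)

# Enumerated value of z as reported in the trace ("value 4 1"):
# all L classes stacked along a new dimension left of the plate dim.
z_enum = torch.arange(L).reshape(L, 1)

gathered = delta_mat[z_enum]                     # shape (L, 1, J, M) = (4, 1, 10, 4)
logits = torch.sum(gathered * beta_mat, dim=2)   # same as axis=2: now sums over J, not M
print(logits.shape)                              # torch.Size([4, 1, 4])
```

With the extra leading dimension, axis 2 indexes the J questions instead of the M components, so the result has shape (4, 1, 4) and its rightmost dimension is M = 4. pyro.plate("assignments", N) sits at dim -1 and expects N = 30 there, hence "30 vs 4".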

Hi,

You need an additional plate for J. Since your Y has shape NxJ and plate dims need to match your data, I would use two plates: pyro.plate("individuals", N, dim=-2) and pyro.plate("questions", J, dim=-1). Then it would look something like this:

    with pyro.plate("individuals", N, dim=-2):
        z = pyro.sample("z", dist.Categorical(torch.ones(L)), infer={"enumerate": "parallel"})
        with pyro.plate("questions", J, dim=-1):
            pyro.sample("Y", dist.Bernoulli(torch.sigmoid(torch.sum((delta_mat[z] * beta_mat), axis = 2))),
                        obs = Y)

You also need to make sure that the Bernoulli probability torch.sigmoid(torch.sum((delta_mat[z] * beta_mat), axis = 2)) has shape LxNxJ (or a shape that broadcasts to it): L is for the enumerated z, N is for the individuals plate, and J is for the questions plate.


Thank you for clarifying the shape of the Bernoulli probability tensor. So, if I understand correctly, I would have to enumerate through my observation matrix Y, similar to the example with the time series:
Inference with Discrete Latent Variables — Pyro Tutorials 1.9.0 documentation.
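Parallel enumeration does not require looping over Y: the enumerated z already covers every class for every individual in one batched tensor, and TraceEnum_ELBO sums z out. As a torch-only illustration of that sum, here is the marginal log-likelihood log p(Y | pi) for a fixed, hypothetical pi, computed in vectorized form and checked against an explicit loop (random beta_mat, delta_mat and Y are stand-ins):

```python
import torch

N, J, L, M = 30, 10, 4, 4
torch.manual_seed(0)
beta_mat = torch.randn(J, M)
delta_mat = torch.randn(L, J, M)
pi = torch.tensor([0.1, 0.2, 0.3, 0.4])  # hypothetical fixed mixing weights
Y = torch.randint(0, 2, (N, J)).float()  # stand-in responses

# log p(Y | pi) = sum_i log sum_l pi_l * prod_j Bern(Y_ij | sigmoid(beta_j^T Delta_{l,j}))
class_logits = (delta_mat * beta_mat).sum(-1)                            # (L, J)
bern = torch.distributions.Bernoulli(logits=class_logits.unsqueeze(1))   # batch (L, 1, J)
log_lik = bern.log_prob(Y).sum(-1)                                       # (L, N, J) -> (L, N)
marginal = torch.logsumexp(pi.log().unsqueeze(1) + log_lik, dim=0).sum()

# Explicit loop over individuals and classes, for comparison.
total = 0.0
for i in range(N):
    terms = []
    for l in range(L):
        lp = torch.distributions.Bernoulli(logits=class_logits[l]).log_prob(Y[i]).sum()
        terms.append(pi[l].log() + lp)
    total = total + torch.logsumexp(torch.stack(terms), dim=0)
```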

I also tried a version without enumeration:

    # assume oracle beta
    @config_enumerate
    def model(Y):
        d0 = torch.ones(L)
        pi = pyro.sample("pi", dist.Dirichlet(d0))

        with pyro.plate("individuals", N, dim=-2):
            z = pyro.sample("z", dist.Categorical(torch.ones(L)), infer={"enumerate": "parallel"})
            with pyro.plate("questions", J, dim=-1):
                # pyro.sample("Y", dist.Bernoulli(torch.sigmoid(torch.sum((delta_mat[z] * beta_mat), axis = 2))),
                #             obs = Y)
                prob = torch.sigmoid(torch.sum(torch.squeeze(delta_mat[z]) * beta_mat, axis = 2))
                pyro.sample("Y", dist.Bernoulli(prob), obs = Y)


    def guide(Y):
        d = pyro.param("d0", torch.ones(L)/L, constraint = constraints.positive)
        pi = pyro.sample("pi", dist.Dirichlet(d))

        with pyro.plate("assignments,", N):
            z_prob = pyro.param(
                "z_prob",
                torch.ones(N, L)/ L,
                constraint = constraints.simplex,
            )
            pyro.sample("z", dist.Categorical(z_prob))

However, I encountered this error:

    ValueError: Model and guide shapes disagree at site 'z': torch.Size([30, 1]) vs torch.Size([30])

I tried printing the shapes of both distributions, and both are [30]. Do you know what could be causing this issue? Thanks!

Reshape z in the guide to shape (30, 1), for example by giving z_prob shape (N, 1, L), and also add dim=-2 to the plate in the guide. Model and guide plates need to match (same name and dim).

Also, if you are marginalizing out z in the model, then you don’t need it in the guide (i.e., you can remove the z sample site from the guide).