Sampling Discrete Latent Variables in Pyro


I am very new to Pyro.

My model includes categorical discrete variables, which I generate using the following code:

X_Us1 = pyro.sample("X_Us1", dist.Categorical(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])), obs=X_ls[:, 0])

The idea is to generate a discrete variable that takes values from 0 to 4 with equal probabilities. It is matched with the observed values stored in the first column of X_ls. The code works fine for this scenario.

At a later stage, I intend to do imputation on a different dataset, where this variable is latent. I modified the above code as follows:

X_Usl1 = pyro.sample("X_Usl1", dist.Categorical(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])), obs=None)

The variable names are changed; apart from that, the code remains the same. I have a series of eight such discrete variables (X_Usl1, …, X_Usl8). However, when I run the model, I run into an issue: these latent tensors seem to have different dimensions.

X_Usl1- 5 1
X_Usl2- 5 1 1
X_Usl8- 5 1 1 1 1 1 1 1 1

Could you please tell me why I am running into this issue for the discrete variables? Also, is there a way to overcome this issue for latent variables?

Thanks in advance.

A MWE (minimum working example) if possible would be very helpful here.

Also here’s a toy model that might help you.

Hi @vishnu_baburajan I suspect your tensors are getting different dimensions due to enumeration. I’d recommend taking a look at the Tensor Shapes Tutorial.

Hi @fritzo, thank you for your support. I wanted to read more and try implementing it before replying, which caused this delay.

I have solved this partially.

def model1(alpha_lsp, gamma_lsp, alpha_oep, gamma_oep, y_att_ls1, y_att_oe1, len_lk1, len_oe1, obs=None):
    # Latent one-hot variables, one sample per row of the dataset
    with pyro.plate("data_oep", len_oe1):
        X_Us1 = pyro.sample("X_Us1", dist.OneHotCategorical(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])), obs=None)
        X_Us2 = pyro.sample("X_Us2", dist.OneHotCategorical(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])), obs=None)
        X_Us8 = pyro.sample("X_Us8", dist.OneHotCategorical(torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2])), obs=None)

    # Drop the size-1 dimensions
    X_Usl1 = torch.squeeze(X_Us1)
    X_Usl2 = torch.squeeze(X_Us2)
    X_Usl8 = torch.squeeze(X_Us8)

    # Concatenate the one-hot blocks along dim 1
    X_Var_lsp = torch.cat((X_Usl1, X_Usl2, X_Usl3, X_Usl4, X_Usl5, X_Atl1, X_Atl2, X_Atl3), 1)

Doing the squeeze allowed me to overcome the issue I mentioned previously. However, I have a new issue. Immediately after this, I am running the following set of commands. The idea is to do an imputation on the X_Usl1, X_Usl2, …, X_Usl8 variables.

    with pyro.plate("data_oepy", len_oe1):
        y_att_oepr = pyro.sample("y_att_oepr", dist.Normal(alpha_lsp + torch.matmul(X_Var_lsp, torch.tensor(gamma_lsp)), 0.1), obs=y_att_oe1)
    return y_att_oepr

This produces the following error:

ValueError: Shape mismatch inside plate('data_oepy') at site y_att_oepr dim -1, 1931 vs 5

I printed the shapes of

  • X_Var_lsp
  • gamma_lsp

This is the output I am getting:
Shape of X_Usl1: torch.Size([1931, 5])
Shape of X_Usl2: torch.Size([1931, 5])
Shape of X_Usl3: torch.Size([1931, 5])
Shape of X_Usl4: torch.Size([1931, 5])
Shape of X_Var_lsp: torch.Size([1931, 40])
Length of len_oe1: 1931
Shape of X_Usl1: torch.Size([5, 5])
Shape of X_Usl2: torch.Size([5, 5])
Shape of X_Usl3: torch.Size([5, 5])
Shape of X_Usl4: torch.Size([5, 5])
Shape of X_Var_lsp: torch.Size([5, 40])
Length of len_oe1: 1931

Please note that the size of the plate is constant at 1931. In the first set of print statements, the shapes are as expected ([1931, 5]). In the subsequent print statements, however, they change to [5, 5], which causes the mismatch in dimensions.
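The change from [1931, 5] to [5, 5] is consistent with the enumeration explanation given earlier in the thread. The sketch below uses plain torch only. It assumes that an enumerated OneHotCategorical value carries a size-5 enumeration dim, size-1 placeholder dims, and a size-5 event dim. Under that assumption torch.squeeze removes the placeholders, so the concatenated matrix ends up with 5 rows instead of 1931. It also shows one possible alternative, offered as an assumption rather than a confirmed fix: multiply each one-hot block by its own slice of the coefficients and let broadcasting combine the results. The coefficient tensor gamma is hypothetical.

```python
import torch

n_cat = 5

# Assumed shapes of two enumerated one-hot values inside one plate:
# (enum, plate placeholder, event) and (enum, 1, plate placeholder, event).
x1 = torch.eye(n_cat).reshape(n_cat, 1, n_cat)
x2 = torch.eye(n_cat).reshape(n_cat, 1, 1, n_cat)

# squeeze() drops every size-1 dim, so both values collapse to (5, 5).
squeezed = torch.cat((torch.squeeze(x1), torch.squeeze(x2)), 1)

# Hypothetical coefficients: one block of 5 per variable.
gamma = torch.arange(2 * n_cat, dtype=torch.float32)

# Shape (5,): this cannot line up with a plate of size 1931 at dim -1.
loc_squeezed = torch.matmul(squeezed, gamma)

# Alternative: keep the size-1 dims and sum per-variable contributions.
loc_broadcast = torch.matmul(x1, gamma[:n_cat]) + torch.matmul(x2, gamma[n_cat:])

print(squeezed.shape, loc_squeezed.shape, loc_broadcast.shape)
```

In the alternative, the result keeps a trailing size-1 dim, so it can broadcast against a plate of any size at dim -1 while the enumeration dims stay separate on the left.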

Could you please help me resolve this issue?

Thanks in advance.