Consider the following model, which is basically a GMM where the cluster centers are themselves random variables (I'm mostly interested in it as a way to understand plating/enumeration better). I notice that during enumeration, if we sample a variable with shape (2, 1, 1) inside a plate of size 2000 that uses dim=-2, the result has shape (2000, 1). I don't understand this - it seems like the shape ought to be (2, 2000, 1).
```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

K = 2  # number of mixture components (matches the (2, 1, 1) shapes below)

@config_enumerate(default="parallel")
def model():
    weights = pyro.sample(
        'weights', dist.Dirichlet(1. / K * torch.ones(K)))
    with pyro.plate('components', K):
        gamma_1 = pyro.sample('gamma_1', dist.Normal(1, 1))
        sigma_gamma_1 = pyro.sample('sigma_gamma_1', dist.InverseGamma(1, 1))
    with pyro.plate('data', 2000, dim=-2):
        # normally, assignment is (2000, 1)
        # during enumeration, assignment is (2, 1, 1)
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        # normally, alpha_11 and alpha_12 are (2000, 1)
        # during enumeration, alpha_11 and alpha_12 are (2, 1, 1)
        alpha_11 = Vindex(gamma_1)[assignment]
        alpha_12 = Vindex(sigma_gamma_1)[assignment]
        # theta_1 is always (2000, 1), but it seems like during enumeration
        # theta_1 should be (2, 2000, 1)
        theta_1 = pyro.sample('theta_1', dist.Normal(alpha_11, alpha_12))
        print(theta_1.shape)

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO()
guide = AutoDiagonalNormal(poutine.block(model, hide=['assignment']))
svi = SVI(model, guide, optim, loss=elbo)
svi.loss(model, guide)
```
If I repeat the experiment with distribution parameters that are not latent variables, I get the shape I expected:
```python
with pyro.plate('s1', 100, dim=-2):
    s2 = pyro.sample('s2', dist.Normal(torch.zeros(2, 1, 1), torch.ones(2, 1, 1)))
    assert s2.shape == (2, 100, 1)
```
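For reference, here is the sanity check that shaped my expectation (my own addition, using plain `torch.distributions` with no Pyro): a `Normal` whose parameters have batch shape (2, 1, 1), standing in for `alpha_11` and `alpha_12` during enumeration, broadcasts against a (2000, 1) value to give (2, 2000, 1).

```python
import torch
import torch.distributions as dist

# Parameters with the enumerated batch shape (2, 1, 1),
# standing in for alpha_11 / alpha_12 above.
loc = torch.zeros(2, 1, 1)
scale = torch.ones(2, 1, 1)
d = dist.Normal(loc, scale)
print(d.batch_shape)            # torch.Size([2, 1, 1])

# Broadcasting those batch dims against a (2000, 1) value,
# i.e. the plate's shape, yields the shape I expected for theta_1:
value = torch.zeros(2000, 1)
print(d.log_prob(value).shape)  # torch.Size([2, 2000, 1])
```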
I think I am missing something - is there something going on that I ought to know about? Thanks in advance.