Understanding the behavior of enumeration within a plate

Consider the following model, which is basically a GMM whose cluster centers are themselves random variables. (I’m mostly interested in it as a way to understand plates and enumeration better.) I notice that during enumeration, if I sample theta_1 from a distribution whose parameters have shape (2, 1, 1) inside a plate of size 2000 that uses dim=-2, the resulting sample has shape (2000, 1). I don’t understand this: it seems like the shape ought to be (2, 2000, 1).

```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.ops.indexing import Vindex

K = 2  # number of components (not given in the post; 2 matches the (2, 1, 1) enumeration shape)

@config_enumerate(default="parallel")
def model():
    weights = pyro.sample(
        'weights', dist.Dirichlet(1./K * torch.ones(K)))

    with pyro.plate('components', K):
        gamma_1 = pyro.sample('gamma_1', dist.Normal(1, 1))
        sigma_gamma_1 = pyro.sample('sigma_gamma_1', dist.InverseGamma(1, 1))

    with pyro.plate('data', 2000, dim=-2):
        # normally, assignment is (2000, 1)
        # during enumeration, assignment is (2, 1, 1)
        assignment = pyro.sample('assignment', dist.Categorical(weights))

        # normally, alpha_11 and alpha_12 are (2000, 1)
        # during enumeration, alpha_11 and alpha_12 are (2, 1, 1)
        alpha_11 = Vindex(gamma_1)[assignment]
        alpha_12 = Vindex(sigma_gamma_1)[assignment]

        # theta_1 is always (2000, 1), but it seems like during enumeration
        # theta_1 should be (2, 2000, 1)
        theta_1 = pyro.sample('theta_1', dist.Normal(alpha_11, alpha_12))
        print(theta_1.shape)

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO()
# Hide the enumerated site so the autoguide only covers continuous latents.
guide = AutoDiagonalNormal(poutine.block(model, hide=['assignment']))
svi = SVI(model, guide, optim, loss=elbo)
svi.loss(model, guide)
```
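The shape comments in the model can be checked with plain PyTorch. For a 1-D tensor like gamma_1 indexed by a single index tensor, Vindex(gamma_1)[assignment] should behave like ordinary advanced indexing, so the result simply takes the shape of assignment. This is a minimal sketch under that assumption; gamma_1 here is a random stand-in, and the enumerated value of assignment is written out by hand as arange(K) placed on dim -3 (left of the two plate dims), which is the layout the (2, 1, 1) shape in the post suggests.

```python
import torch

K = 2  # number of components, matching the (2, 1, 1) enumeration shape
gamma_1 = torch.randn(K)  # stand-in for the sampled component means, shape (K,)

# Ordinary sampling: one assignment per datapoint, with the data plate on dim -2.
assignment = torch.randint(0, K, (2000, 1))
alpha_11 = gamma_1[assignment]
print(alpha_11.shape)  # torch.Size([2000, 1])

# Parallel enumeration (hand-built here): every value 0..K-1 on a new dim -3.
assignment_enum = torch.arange(K).reshape(K, 1, 1)
alpha_11_enum = gamma_1[assignment_enum]
print(alpha_11_enum.shape)  # torch.Size([2, 1, 1])
```

So during enumeration the parameters of theta_1 really do carry the extra leading dimension; the open question is only why the sampled value does not.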

If I repeat the experiment with distribution parameters that are not latent variables, I get the expected result.

```python
import torch
import pyro
import pyro.distributions as dist

with pyro.plate('s1', 100, dim=-2):
    s2 = pyro.sample('s2', dist.Normal(torch.zeros(2, 1, 1), torch.ones(2, 1, 1)))

assert s2.shape == (2, 100, 1)
```
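A rough way to see where (2, 100, 1) comes from: the plate expands the distribution's batch shape so that it broadcasts with size 100 at dim -2, and because s2 is actually drawn from that distribution, the sample inherits the full batch shape. The sketch below reproduces only the shape arithmetic with torch.distributions; it is an approximation of what pyro.plate does, not Pyro's implementation.

```python
import torch

# Batch shape of the Normal parameters in the snippet above.
param_shape = (2, 1, 1)
# pyro.plate('s1', 100, dim=-2) asks for size 100 at dim -2.
plate_shape = (100, 1)

sample_shape = torch.broadcast_shapes(param_shape, plate_shape)
print(sample_shape)  # torch.Size([2, 100, 1])

# Drawing from a Normal whose batch shape is expanded the same way.
d = torch.distributions.Normal(torch.zeros(param_shape), torch.ones(param_shape))
s2 = d.expand(sample_shape).sample()
print(s2.shape)  # torch.Size([2, 100, 1])
```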

I think I am missing something - is there something going on that I ought to know about? Thanks in advance.

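Hi @ethansargent, this behavior is caused by your use of a guide for theta_1. Because theta_1 is a latent site, AutoDiagonalNormal samples it, and during inference the model replays that guide sample rather than drawing a new one from dist.Normal(alpha_11, alpha_12). The samples in the guide don’t depend on enumerated variables (assignment, in this case), so they will not have the extra batch dimension introduced by enumeration. The enumeration dimension appears instead in the log-probability of theta_1, which is summed out over assignment by TraceEnum_ELBO.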
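Conceptually, the enumerated dimension lives in the scores rather than in the value of theta_1. The sketch below uses random stand-ins for the latent parameters and a random (2000, 1) tensor in place of the guide sample. Scoring that sample under parameters of shape (2, 1, 1) broadcasts to (2, 2000, 1), and a logsumexp over dim 0 marginalizes out assignment. TraceEnum_ELBO performs this marginalization through its own tensor contraction, so this is an illustration of the idea, not its actual code path.

```python
import torch
from torch.distributions import Normal

K = 2
torch.manual_seed(0)

# Stand-ins for the sampled latent parameters, each of shape (K,).
gamma_1 = torch.randn(K)
sigma_gamma_1 = torch.rand(K) + 0.5
weights = torch.full((K,), 1.0 / K)

# Enumerated assignment and the parameters it selects: shape (K, 1, 1).
assignment = torch.arange(K).reshape(K, 1, 1)
alpha_11 = gamma_1[assignment]
alpha_12 = sigma_gamma_1[assignment]

# theta_1 comes from the guide, which never sees assignment: shape (2000, 1).
theta_1 = torch.randn(2000, 1)

# Scoring the guide sample under the model broadcasts to (K, 2000, 1).
log_p_theta = Normal(alpha_11, alpha_12).log_prob(theta_1)

# Add log p(assignment) (shape (K, 1, 1)) and sum out the enumerated dim,
# leaving one marginal log-likelihood term per datapoint: shape (2000, 1).
log_p_assign = weights.log()[assignment]
log_marginal = torch.logsumexp(log_p_assign + log_p_theta, dim=0)
print(log_p_theta.shape, log_marginal.shape)
```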


Thanks!