Clustering MNIST using amortized inference

Hi there! I’m working on an amortized inference procedure for clustering MNIST data. My model and guide largely follow the “Dirichlet Process Mixture Models in Pyro” walkthrough.

The main difference in my code is the use of neural networks to learn tau and alpha, the parameters of the MultivariateNormal for locs and the Dirichlet for weights, respectively. tau_mlp and alpha_mlp are two neural networks that extend PyroModule.
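For concreteness, here is a minimal sketch of their structure (the hidden sizes and T below are illustrative; only the output shapes matter for the guide further down):

import torch.nn as nn
from pyro.nn import PyroModule

M, T = 784, 10  # flattened MNIST pixels; number of mixture components (illustrative)

class TauMLP(PyroModule):
    def __init__(self, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(M, hidden), nn.ReLU(),
                                 nn.Linear(hidden, T * M))

    def forward(self, x):
        # pool over the batch so the output is a single flattened
        # (T, M) matrix of component means for the whole dataset
        return self.net(x).mean(dim=0)

class AlphaMLP(PyroModule):
    def __init__(self, hidden=256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(M, hidden), nn.ReLU(),
                                 nn.Linear(hidden, T), nn.Softplus())

    def forward(self, x):
        # Softplus keeps the Dirichlet concentrations positive
        return self.net(x).mean(dim=0)

tau_mlp, alpha_mlp = TauMLP(), AlphaMLP()

Here are the model and guide themselves: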

import torch
import pyro
from pyro.distributions import Categorical, Dirichlet, MultivariateNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# N is the number of data points; data has shape (N, M)

def model(data, step):
    # sample mixture weights
    alpha = torch.ones(T)
    weights = pyro.sample('weights', Dirichlet(alpha))
    
    # sample mixture means
    with pyro.plate('components', T):
        locs = pyro.sample('locs', 
                           MultivariateNormal(torch.zeros(M),
                                              torch.eye(M)))
        
    # sample cluster assignments and observe
    with pyro.plate('data', size=N):
        assignments = pyro.sample('assignments',
                                  Categorical(weights))
        pyro.sample('obs',
                    MultivariateNormal(locs[assignments],
                                       torch.eye(M)),
                    obs=data)  # event dim is -1; dim 0 is the data plate
        
def guide(data, step):
    pyro.module('alpha_mlp', alpha_mlp)
    pyro.module('tau_mlp', tau_mlp)

    if use_gpu: 
        data = data.cuda()
        
    # nn to approximate posterior for mixture means
    tau = tau_mlp(data.float())
    tau = tau.view(T, M)

    # sample mixture means
    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(tau, torch.eye(M)))
    
    # nn to approximate posterior for mixture weights 
    alpha = alpha_mlp(data.float()) # returns a vector of length T
    weights = pyro.sample('weights', Dirichlet(alpha))  # vector of length T
    
    # sample cluster assignments
    with pyro.plate('data', size=N):
        assignments = pyro.sample('assignments', Categorical(weights))


adam_params = {"lr": 0.001, "weight_decay": 0.01}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

elbo_ests = []
for step in range(n_steps):
    elbo_est = svi.step(data, step)
    elbo_ests.append(elbo_est)

After running the inference procedure, the model didn’t seem to learn anything: the resulting clustering is more or less random.
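(My actual evaluation code isn’t shown above, but one quick way to inspect the learned clustering is to hard-assign each image to its nearest learned component mean, roughly like this:)

with torch.no_grad():
    # posterior component means from the guide's network
    tau = tau_mlp(data.float()).view(T, M)
    # hard-assign each image to its nearest component mean
    hard_assignments = torch.cdist(data.float(), tau).argmin(dim=1)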

I’m not sure where I’m going wrong. Is it how I’m using the neural networks in the guide, or is it something else I’m missing? Any help would be appreciated!

You might consider marginalizing out the discrete assignments with enumeration and switching to TraceEnum_ELBO, as in the Gaussian Mixture Model tutorial. With plain Trace_ELBO, the guide samples the assignments site and the resulting score-function gradient estimator has very high variance, which often keeps a mixture model from learning anything.
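A minimal sketch of that change on top of your model and guide above (the guide would then drop its assignments sample statement, since the enumerated site lives only in the model and is summed out of the ELBO exactly):

from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

# enumerate the Categorical 'assignments' site in parallel so it is
# marginalized out exactly instead of being sampled
enum_model = config_enumerate(model, "parallel")

svi = SVI(enum_model, guide, Adam({"lr": 0.001}),
          loss=TraceEnum_ELBO(max_plate_nesting=1))

Here max_plate_nesting=1 tells Pyro that the deepest plate ('data') sits at dim -1, so the enumeration dimension can safely be allocated to the left of it.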