Hi there! I’m working on an amortized inference procedure for clustering MNIST data. My model and guide are largely based on the “Dirichlet Process Mixture Models in Pyro” walkthrough.

The main difference in my code, however, is the use of neural networks to learn tau and alpha, which are the parameters of the MultivariateNormal for locs and the Dirichlet for weights, respectively. Here’s my code (tau_mlp and alpha_mlp are two neural networks that extend PyroModule):

```
import torch
import pyro
from pyro.distributions import Categorical, Dirichlet, MultivariateNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# T (number of components), M (data dimension), N (number of data points),
# use_gpu, data, n_steps, tau_mlp and alpha_mlp are defined elsewhere.

def model(data, step):
    # sample mixture weights
    alpha = torch.ones(T)
    weights = pyro.sample('weights', Dirichlet(alpha))
    # sample mixture means
    with pyro.plate('components', T):
        locs = pyro.sample('locs',
                           MultivariateNormal(torch.zeros(M),
                                              torch.eye(M)))
    # sample cluster assignments and observe
    with pyro.plate('data', size=N):
        assignments = pyro.sample('assignments',
                                  Categorical(weights))
        pyro.sample('obs',
                    MultivariateNormal(locs[assignments],
                                       torch.eye(M)),
                    obs=data)  # dim-1 b/c dim-0 is batch


def guide(data, step):
    pyro.module('alpha_mlp', alpha_mlp)
    pyro.module('tau_mlp', tau_mlp)
    if use_gpu:
        data = data.cuda()
    # nn to approximate posterior for mixture means
    tau = tau_mlp(data.float())
    tau = tau.view(T, M)
    # sample mixture means
    with pyro.plate('components', T):
        locs = pyro.sample('locs', MultivariateNormal(tau, torch.eye(M)))
    # nn to approximate posterior for mixture weights
    alpha = alpha_mlp(data.float())  # returns a vector of length T
    weights = pyro.sample('weights', Dirichlet(alpha))  # vector of length T
    # sample cluster assignments
    with pyro.plate('data', size=N):
        assignments = pyro.sample('assignments', Categorical(weights))


adam_params = {"lr": 0.001, "weight_decay": 0.01}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

elbo_ests = []
for step in range(n_steps):
    elbo_est = svi.step(data, step)
    elbo_ests.append(elbo_est)
```
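
Written out as equations, the code above corresponds to the following generative model and variational family. This is only a transcription of the snippet; the symbols $w$, $\mu_k$, $z_n$, $x_n$, $\tau_k$ and $\alpha$ are shorthand introduced here for weights, locs, assignments, the individual data points, and the outputs of tau_mlp and alpha_mlp.

$$
\begin{aligned}
\text{model:}\quad & w \sim \mathrm{Dirichlet}(\mathbf{1}_T), \\
& \mu_k \sim \mathcal{N}(\mathbf{0}, I_M), \quad k = 1, \dots, T, \\
& z_n \mid w \sim \mathrm{Categorical}(w), \quad n = 1, \dots, N, \\
& x_n \mid z_n, \mu \sim \mathcal{N}(\mu_{z_n}, I_M), \\[6pt]
\text{guide:}\quad & q(\mu_k \mid x_{1:N}) = \mathcal{N}\big(\tau_k(x_{1:N}), I_M\big), \quad k = 1, \dots, T, \\
& q(w \mid x_{1:N}) = \mathrm{Dirichlet}\big(\alpha(x_{1:N})\big), \\
& q(z_n \mid w) = \mathrm{Categorical}(w), \quad n = 1, \dots, N.
\end{aligned}
$$

Read this way, the guide’s factor for each $z_n$ depends on the data only through the sampled $w$, not on $x_n$ itself, and both networks have to map the whole batch $x_{1:N}$ to a single output ($T \times M$ values for $\tau$, $T$ values for $\alpha$).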

After running the inference procedure, the model didn’t seem to learn anything: the resulting clustering is more or less random.
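
To put a number on “more or less random”, one option is cluster purity against the MNIST digit labels. The sketch below is a generic helper, not something from the walkthrough; the function name `cluster_purity` and the choice of purity as the metric are assumptions made for illustration.

```python
from collections import Counter, defaultdict


def cluster_purity(assignments, labels):
    """Fraction of points whose label equals the majority label of their cluster.

    `assignments` and `labels` are equal-length sequences of hashable ids
    (hypothetical inputs: hard cluster ids and ground-truth digit labels).
    """
    if len(assignments) != len(labels):
        raise ValueError("assignments and labels must have the same length")
    if len(labels) == 0:
        raise ValueError("need at least one data point")

    # per_cluster[k][c] = number of points in cluster k with true label c
    per_cluster = defaultdict(Counter)
    for k, c in zip(assignments, labels):
        per_cluster[k][c] += 1

    # each cluster contributes the size of its largest label group
    majority_total = sum(counts.most_common(1)[0][1]
                         for counts in per_cluster.values())
    return majority_total / len(labels)


# toy check: cluster 0 holds labels {3, 3}, cluster 1 holds labels {7, 3}
print(cluster_purity([0, 0, 1, 1], [3, 3, 7, 3]))  # 0.75
```

For the model above, `assignments` could come from hard-assigning each image to its nearest posterior mean (converted with `.tolist()`), which is again an assumption about how the clustering is read out. For random assignments on a roughly balanced 10-class dataset such as MNIST, purity should sit near 0.1 when each cluster holds many points, so a value close to that would confirm the impression of a random clustering.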

I’m not sure where I’m going wrong - is it how I’m using the neural networks in the guide? Or perhaps it is something else that I’m missing? Any help would be appreciated!