I’m trying to implement a Dirichlet process mixture model in Pyro using the Chinese restaurant process formulation. I’ve used Fritz’s rough code sketch from here and have been trying to combine it with parts from the PyMC3 tutorial on DPMMs. The dataset can be found here.
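For reference, here is the CRP seating rule the model relies on, sketched in plain Python outside of Pyro. The name `crp_seating` is a hypothetical helper. At step `i`, existing table `k` gets probability `counts[k] / (i + alpha)` and a new table gets `alpha / (i + alpha)`:

```python
import random


def crp_seating(num_customers, alpha, seed=0):
    """Simulate table assignments under a Chinese restaurant process with concentration alpha."""
    rng = random.Random(seed)
    counts = []        # counts[k] = number of customers already seated at table k
    assignments = []   # assignments[i] = table chosen by customer i
    for i in range(num_customers):
        # Existing table k has weight counts[k], a new table has weight alpha.
        # Since sum(counts) == i, dividing by (i + alpha) normalises the weights.
        weights = [c / (i + alpha) for c in counts] + [alpha / (i + alpha)]
        z = rng.choices(range(len(weights)), weights=weights)[0]
        if z == len(counts):
            counts.append(1)   # open a new table
        else:
            counts[z] += 1     # join an existing table
        assignments.append(z)
    return assignments, counts
```

Larger `alpha` makes new tables more likely, which is why the model puts a Gamma prior on it.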
So far, I have this:
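```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

# df is the pandas DataFrame holding the sunspot dataset linked above
data = torch.tensor(df['sunspot.year'].values, dtype=torch.float32)

def crp_model(data):
    alpha0 = pyro.sample('alpha', dist.Gamma(1., 1.))
    cluster_rates = {}  # table index -> Poisson rate
    crp_counts = []     # crp_counts[k] = customers seated at table k
    for i in range(len(data)):
        # Existing tables are weighted by their counts, a new table by alpha0.
        # torch.cat keeps alpha0 in the autograd graph; torch.tensor(...) would copy it.
        crp_weights = torch.cat([torch.tensor(crp_counts, dtype=torch.float32),
                                 alpha0.reshape(1)])
        crp_weights = crp_weights / crp_weights.sum()
        zi = pyro.sample("z_{}".format(i), dist.Categorical(crp_weights)).item()
        if zi >= len(crp_counts):
            crp_counts.append(1)  # sit at a new table
        else:
            crp_counts[zi] += 1   # sit at an existing table
        if zi not in cluster_rates:
            cluster_rates[zi] = pyro.sample("lambda_{}".format(zi), dist.Uniform(0., 20.))
        lambda_i = cluster_rates[zi]
        pyro.sample("obs_{}".format(i), dist.Poisson(lambda_i), obs=data[i])

# MAP (AutoDelta) guide over alpha and the rates; the discrete z_i sites are hidden from it
guide = AutoDelta(poutine.block(crp_model, hide=['z_{}'.format(i) for i in range(len(data))]))
optim = Adam({"lr": 0.01})
svi = SVI(crp_model, guide, optim, loss=TraceEnum_ELBO(), num_samples=1000)

def train(num_iterations):
    pyro.clear_param_store()
    for j in range(num_iterations):
        loss = svi.step(data)
        if j % 500 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

train(5000)

for name, value in pyro.get_param_store().items():
    print(name, value)
```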
This gives me a set of 7 Poisson rate parameters (roughly what I expected, judging from the PyMC3 tutorial), but I have no way of accessing the `crp_weights` tensor that is defined inside the model. Without those weights I can't build the mixture of Poissons that describes my data, so I can't actually inspect the clustering being learned.
Does anyone know how I might track `crp_weights` as they evolve during inference? Since they're not a Pyro param, they don't end up in the Pyro param store.
Also, I’d really appreciate help making a plot of the posterior mixture of Poissons once inference is complete. I haven’t found a straightforward way to access the posterior and plot it.
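A minimal plotting sketch, assuming you already have per-cluster weights (for example, normalised table counts with the new-table mass dropped) and rates (for example, the AutoDelta point estimates for each `lambda_k`): evaluate p(y) = Σ_k w_k · Poisson(y | λ_k) on a grid of counts and overlay it on a normalised histogram of the data. `poisson_pmf`, `poisson_mixture_pmf` and `plot_mixture` are hypothetical helpers, and matplotlib is assumed to be installed (it is only imported when plotting).

```python
import math


def poisson_pmf(y, rate):
    """Poisson probability of count y, computed in log space to avoid overflow."""
    if rate == 0:
        return 1.0 if y == 0 else 0.0
    return math.exp(y * math.log(rate) - rate - math.lgamma(y + 1))


def poisson_mixture_pmf(ys, weights, rates):
    """p(y) = sum_k w_k * Poisson(y | rate_k); weights are renormalised to sum to 1."""
    total = sum(weights)
    return [sum(w / total * poisson_pmf(y, r) for w, r in zip(weights, rates))
            for y in ys]


def plot_mixture(weights, rates, data, max_y=None):
    """Overlay the mixture pmf on a density-normalised histogram of the observed counts."""
    import matplotlib.pyplot as plt  # lazy import: only needed when plotting
    data = [int(d) for d in data]
    max_y = max(data) if max_y is None else max_y
    ys = list(range(max_y + 1))
    plt.hist(data, bins=range(max_y + 2), density=True, alpha=0.4, label="data")
    plt.plot(ys, poisson_mixture_pmf(ys, weights, rates), "k-", label="fitted mixture")
    plt.xlabel("count")
    plt.ylabel("probability")
    plt.legend()
    plt.show()
```

With the model in the question, `data.tolist()` supplies the observations, while the weights and rates would come from the fitted model.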