Chinese Restaurant Process (mixture model) - not seeing mixture parameters in my param_store?

I’m trying to implement a Dirichlet process mixture model in Pyro using the Chinese restaurant process formulation. I’ve used Fritz’s rough code sketch from here and have been trying to combine it with parts from the PyMC3 tutorial on DPMMs. The dataset can be found here.

So far, I have this:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

data = torch.tensor(df['sunspot.year'].values, dtype=torch.float32)

def crp_model(data):
    alpha0 = pyro.sample('alpha', dist.Gamma(1.0, 1.0))
    cluster_rates = {}  # sample these lazily
    crp_counts = []  # build this incrementally
    for i in range(len(data)):
        # CRP weights: existing table counts, plus alpha0 for a new table
        crp_weights = torch.cat([torch.tensor(crp_counts, dtype=torch.float32),
                                 alpha0.reshape(1)])
        crp_weights /= crp_weights.sum()

        zi = pyro.sample("z_{}".format(i), dist.Categorical(crp_weights))
        zi = zi.item()

        if zi >= len(crp_counts):
            crp_counts.append(1)  # sit at a new table
        else:
            crp_counts[zi] += 1  # sit at an existing table

        if zi not in cluster_rates:
            cluster_rates[zi] = pyro.sample("lambda_{}".format(zi), dist.Uniform(0.0, 20.0))
        lambda_i = cluster_rates[zi]
        pyro.sample("obs_{}".format(i), dist.Poisson(lambda_i), obs=data[i])

guide = AutoDelta(poutine.block(crp_model, hide=['z_{}'.format(i) for i in range(len(data))]))
optim = Adam({"lr": 0.01})
svi = SVI(crp_model, guide, optim, loss=TraceEnum_ELBO(), num_samples=1000)

def train(num_iterations):
    for j in range(num_iterations):
        loss = svi.step(data)
        if j % 500 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))


for name, value in pyro.get_param_store().items():
    print(name, value)

This gives me a set of 7 Poisson rate parameters (roughly what is expected, at least from cross-referencing the PyMC3 tutorial), but I have no way of accessing the crp_weights tensor that is defined in the model. Without those weights I can't construct the mixture of Poissons that describes my data, so I have no way of actually looking at the clustering that I'm learning.

Does anyone know how I might be able to track crp_weights as they evolve during inference? Since they’re not a Pyro param, they don’t end up in the Pyro param store :frowning:.

Also, I’d really appreciate it if someone could help me figure out how to make a plot of my posterior / scale mixture of Poissons after inference is complete. I haven’t really seen a straightforward way of accessing the posterior and computing a plot.
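On the plotting question: once you have point estimates for the mixture weights and the Poisson rates, you can evaluate the mixture pmf on a grid and plot it against a histogram of the data. A minimal sketch, where the `weights` and `rates` values below are made-up placeholders rather than fitted values:

```python
import torch

# Placeholder point estimates -- substitute the values recovered from
# the param store / guide after inference.
weights = torch.tensor([0.5, 0.3, 0.2])    # mixture weights (simplex)
rates = torch.tensor([10.0, 55.0, 110.0])  # Poisson rates

xs = torch.arange(0, 200, dtype=torch.float32)
# pmf of each component at each x, shape (num_components, len(xs))
component_pmfs = torch.distributions.Poisson(rates.unsqueeze(1)).log_prob(xs).exp()
# weighted sum over components gives the mixture pmf, shape (len(xs),)
mixture_pmf = (weights.unsqueeze(1) * component_pmfs).sum(dim=0)
print(mixture_pmf.sum())  # close to 1, since the grid covers nearly all the mass
```

Overlaying `plt.plot(xs.numpy(), mixture_pmf.numpy())` on `plt.hist(data.numpy(), density=True)` then gives a visual check of the fit.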

Perhaps this was obvious, but this was solved by registering the weights as a Pyro param inside the model loop:

from torch.distributions import constraints

weights = torch.cat([torch.tensor(crp_counts, dtype=torch.float32), alpha0.reshape(1)])
weights /= weights.sum()
crp_weights = pyro.param("weights_{}".format(i), weights, constraint=constraints.simplex)

Now I am dealing with another issue.

When I define my own guide, or use an AutoGuide such as in the following code, there are no lambda_i entries in the Pyro param store, so I can't construct my scale mixture of Poisson distributions.

guide = AutoIAFNormal(poutine.block(crp_model, expose=['weights_{}'.format(i) for i in range(len(data))] + ['lambda_{}'.format(i) for i in range(30)]))

guide = AutoDiagonalNormal(poutine.block(crp_model, hide=['z_{}'.format(i) for i in range(len(data))]))

def guide(data):
    alpha_q = pyro.sample('alpha_q', dist.Gamma(2.0, 0.5))
    cluster_rates_q = {}  # sample these lazily
    crp_counts_q = []  # build this incrementally
    for i in range(len(data)):
        # sample from a CRP
        weights_q = torch.cat([torch.tensor(crp_counts_q, dtype=torch.float32),
                               alpha_q.reshape(1)])
        weights_q /= weights_q.sum()
        crp_weights_q = pyro.param("weights_q_{}".format(i), weights_q, constraint=constraints.simplex)

        zi_q = pyro.sample("z_q_{}".format(i), dist.Categorical(crp_weights_q))
        zi_q = zi_q.item()

        if zi_q >= len(crp_counts_q):
            crp_counts_q.append(1)  # sit at a new table
        else:
            crp_counts_q[zi_q] += 1  # sit at an existing table

        if zi_q not in cluster_rates_q:
            cluster_rates_q[zi_q] = pyro.sample("lambda_q_{}".format(zi_q), dist.Uniform(0.0, 200.0))
        lambda_i_q = cluster_rates_q[zi_q]
        pyro.sample("obs_q_{}".format(i), dist.Poisson(lambda_i_q), obs=data[i])

At least with the AutoDiagonalNormal guide you can examine parameters via the .median() method; I use this all the time:

median = guide.median(data)

Another way to extract those variables is to poutine.trace the guide and examine the sampled values:

from pyro import poutine
trace = poutine.trace(guide).get_trace(data)
# each sample site's value is stored in the trace,
# e.g. trace.nodes["lambda_0"]["value"] for a site named "lambda_0"