Issue with mixture models


Many thanks for all the support you provide on this. This is really great. (I wasn’t sure whether to post this as a GitHub issue or here; I hope this is the right place.)

I would be surprised if it’s a bug; most likely it is my own mistake, but it does look weird. My issue seemingly has to do with the sampling of torch’s Categorical distribution. My model is essentially a mixture of regressions, so within the model definition I have:

    weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))

    for j in pyro.plate("data_plate", NT):
        z = pyro.sample(f"assignment_{j}", dist.Categorical(weights))
        mean = torch.matmul(beta[:,:,z],x[:,some_index[j]]) # pxd by dx1 = px1 , never mind this
        obs = pyro.sample(f"obs_{j}", dist.Normal(mean, Sigma), obs=data[j,:])

And then I’m running a NUTS sampler on this.

In the snippet, z is the membership for sample j, which can be 0, 1 or 2. Thanks to my print(z), I see that z is assigned one of the 3 values at each sample as it should be (for example tensor(1)). But suddenly, for some reason (and this always happens at j==0 after it has already passed through all data points at least once), z gets assigned tensor([0, 1, 2]). Then obviously it breaks because torch.matmul is getting wrong dimensions.
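
A possible explanation, offered as an assumption rather than a confirmed diagnosis: NUTS cannot sample discrete latent variables directly, so Pyro may be enumerating the Categorical site in parallel, in which case z holds all K candidate values at once (tensor([0, 1, 2])) instead of a single draw. The sketch below uses plain PyTorch only, with made-up sizes p, d, K and a hypothetical helper component_mean, to show an indexing pattern that works for both a scalar z and an enumerated z.

```python
import torch

p, d, K = 2, 4, 3  # assumed sizes, chosen only for illustration
beta = torch.arange(p * d * K, dtype=torch.float32).reshape(p, d, K)
x_col = torch.ones(d)  # stands in for x[:, some_index[j]]


def component_mean(beta, x_col, z):
    """Return beta[:, :, z] @ x_col for a scalar or batched integer z.

    The output shape is z.shape + (p,).
    """
    # Move the component axis to the front so that indexing with z
    # leaves the trailing (p, d) matrix dimensions intact.
    b = beta.permute(2, 0, 1)[z]  # shape: z.shape + (p, d)
    return b @ x_col              # shape: z.shape + (p,)


z_scalar = torch.tensor(1)        # an ordinary single draw
z_enum = torch.tensor([0, 1, 2])  # all K values at once, as in the report

print(component_mean(beta, x_col, z_scalar).shape)  # torch.Size([2])
print(component_mean(beta, x_col, z_enum).shape)    # torch.Size([3, 2])
```

With this pattern the mean for an enumerated z simply gains a leading dimension of size K, one row per candidate component, instead of breaking torch.matmul.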

I know that I could vectorize this, but I’m trying to keep it as simple as I can for debugging purposes. Also, I’m not sure how to vectorize this using “with pyro.plate” so that z indexes the third dimension of a torch tensor. But that’s another question!
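
For the vectorization part, here is a minimal sketch in plain PyTorch of how a whole batch of assignments can index the third dimension of beta at once. The shapes are assumptions read off the comment in the snippet (beta is p x d x K, x is d x T), and the sizes, the random data, and the contents of some_index are made up for illustration.

```python
import torch

torch.manual_seed(0)
p, d, K, T, NT = 2, 4, 3, 5, 6           # assumed sizes for illustration
beta = torch.randn(p, d, K)              # one p x d matrix per component
x = torch.randn(d, T)                    # one d-vector per time index
some_index = torch.randint(0, T, (NT,))  # hypothetical time index per data point
z = torch.randint(0, K, (NT,))           # one component label per data point

# beta[:, :, z] has shape (p, d, NT); x[:, some_index] has shape (d, NT).
# einsum contracts over d separately for every data point n.
means = torch.einsum("pdn,dn->np", beta[:, :, z], x[:, some_index])  # (NT, p)

# Cross-check against the explicit per-sample loop from the model.
loop = torch.stack(
    [beta[:, :, z[j]] @ x[:, some_index[j]] for j in range(NT)]
)
assert torch.allclose(means, loop, atol=1e-5)
```

Inside a model, means would be the location of the observation distribution under “with pyro.plate”. If the assignments end up being enumerated, z may carry extra leading dimensions, and Pyro’s Vindex helper (pyro.ops.indexing.Vindex) is meant for indexing in that situation; whether it fits this model is something I have not checked.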

Any clues?

Thanks again

I just want to add that I think it was a mistake to use a Normal instead of a MultivariateNormal, so the snippet would now look like this:

    for j in pyro.plate("data_plate", NT):
        z = pyro.sample(f"assignment_{j}", dist.Categorical(weights))
        b = torch.reshape(beta[:,z],(p,d))
        mu = b @ tmat[:,time[j]]
        pyro.sample(f"obs_{j}", dist.MultivariateNormal(mu, Sigma), obs=X)

And here’s an example of the sampling of z (at the end it samples wrongly):



But it’s still complaining.

Hey, I don’t have a solution to this but wondering if you found one. I’m running into a similar problem and hoping I could learn from anything you’ve found out. Thanks!