Hi
Many thanks for all the support you provide on this. This is really great. (I wasn't sure whether to post this as a GitHub issue or here; I hope this is the right place.)
I would be surprised if it's a bug; most likely it's my own mistake, but it does look weird. My issue seemingly has to do with the sampling of torch's Categorical distribution. My model is essentially a mixture of regressions, so within the model definition I have:
weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
for j in pyro.plate("data_plate", NT):
    z = pyro.sample(f"assignment_{j}", dist.Categorical(weights))
    print(z)
    mean = torch.matmul(beta[:, :, z], x[:, some_index[j]])  # (p x d) by (d x 1) = (p x 1); never mind this
    obs = pyro.sample(f"obs_{j}", dist.Normal(mean, Sigma), obs=data[j, :])
And then I’m running a NUTS sampler on this.
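For reference, this is roughly how I'm invoking it; the model name and the sampler settings below are placeholders, not my exact values:

    from pyro.infer import MCMC, NUTS

    # "model" is the function containing the snippet above;
    # num_samples / warmup_steps are placeholder values
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=500, warmup_steps=200)
    mcmc.run()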
In the snippet, z is the mixture membership for data point j, which can be 0, 1, or 2. Thanks to my print(z), I can see that z is assigned one of the three values at each sample, as it should be (for example, tensor(1)). But suddenly, for some reason (and this always happens at j == 0, after the sampler has already passed through all the data points at least once), z gets assigned tensor([0, 1, 2]). Then it obviously breaks, because torch.matmul gets the wrong dimensions.
I know that I could vectorize this, but I'm trying to keep it as simple as I can for debugging purposes. Also, I'm not sure how to vectorize it using "with pyro.plate" so that z indexes the third dimension of a torch tensor (see the sketch below for roughly what I have in mind). But that's another question!
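For what it's worth, here is a rough sketch of the vectorized version I've been imagining. I'm assuming beta has shape (p, d, K), x has shape (d, ·), some_index is a length-NT index tensor, and data has shape (NT, p); I haven't verified that this actually works, and I suspect the indexing is exactly where I'd get stuck:

    with pyro.plate("data_plate", NT):
        z = pyro.sample("assignment", dist.Categorical(weights))   # shape (NT,)
        b = beta[:, :, z].permute(2, 0, 1)                         # (NT, p, d)
        xj = x[:, some_index].T.unsqueeze(-1)                      # (NT, d, 1)
        mean = torch.bmm(b, xj).squeeze(-1)                        # (NT, p)
        pyro.sample("obs", dist.Normal(mean, Sigma).to_event(1), obs=data)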
Any clues?
Thanks again
Diego