The short tutorial behaves the same way: the sampled Categorical variable has shape (num_samples,) rather than (num_samples, num_observations). Shouldn’t there be one sampled Categorical per observation?
import numpyro
import numpyro.distributions

def model(probs, locs):
    c = numpyro.sample("c", numpyro.distributions.Categorical(probs))  # one draw per model run
    numpyro.sample("x", numpyro.distributions.Normal(locs[c], 0.5))