def hmm_model(K, N, observations, concentration, mu_loc, mu_scale, sigma_shape, sigma_scale):
X = {}
Y = {}
theta = pyro.sample("theta", dist.Dirichlet(torch.ones((K, K)) * concentration / K))
mus = pyro.sample(
"mus", dist.Normal(torch.ones(K) * mu_loc, torch.ones(K) * mu_scale)
)
sigmas = pyro.sample(
"sigmas", dist.Normal(torch.ones(K) * sigma_shape, torch.ones(K) * sigma_scale)
)
X[0] = pyro.sample(
"X[0]", dist.Categorical(tensor([1.0] + [0.0] * (K - 1))), obs=tensor(0)
)
Y[0] = pyro.sample(
"Y[0]", dist.Normal(mus[X[0]], sigmas[X[0]]), obs=observations[0]
)
for i in pyro.markov(range(1, N)):
X[i] = pyro.sample(
"X[" + str(i) + "]",
dist.Categorical(theta[X[i - 1]]),
infer={"enumerate": "parallel"},
)
Y[i] = pyro.sample(
"Y[" + str(i) + "]",
dist.Normal(mus[X[i]], sigmas[X[i]]),
obs=observations[i],
)
Hi, I am trying to run a simple HMM model in Pyro with NUTS and enumeration for discrete variables. However, I am running into an issue with the shape of theta (transition matrix). Any ideas what the error means?