In the Pyro HMM tutorial, `model_0` is defined as follows:
def model_0(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM with a discrete hidden state emitting Bernoulli tones, one sequence at a time.

    Sequences are visited one by one in Python, which keeps tensor shapes simple.

    Args:
        sequences: Observed tensor unpacked below as
            ``(num_sequences, max_length, data_dim)``.
        lengths: Per-sequence lengths; ``sequences[i, :lengths[i]]`` is the
            observed part of sequence ``i``.
        args: Namespace that must provide ``args.hidden_dim`` (number of
            hidden states).
        batch_size: Optional subsample size forwarded to the ``sequences`` plate.
        include_prior: Forwarded to ``poutine.mask`` around the ``probs_x`` and
            ``probs_y`` sample sites; when false those prior terms are masked out.

    Why this model is run with ``max_plate_nesting=1``:
        The ``sequences`` plate is *iterated* with a ``for`` loop. That makes it
        a sequential plate: each iteration runs in Python, and no tensor batch
        dimension is allocated for it. Only ``tones_plate`` is entered as a
        context manager (a vectorized plate), and it occupies ``dim=-1``. So
        plates never need more than one batch dimension here, and
        ``max_plate_nesting=1`` is sufficient. ``max_plate_nesting`` is an upper
        bound, so a larger value such as 2 also works. It simply reserves an
        unused batch dim and allocates enumeration dims further to the left.
    """
    num_sequences, _, data_dim = sequences.shape
    with poutine.mask(mask=include_prior):
        # Transition matrix: Dirichlet concentration 1.0 on the diagonal and 0.1
        # elsewhere, so the prior favours self-transitions.
        probs_x = pyro.sample(
            "probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1),
        )
        # Emission probabilities: one Bernoulli parameter per (hidden state, tone),
        # with a Beta(0.1, 0.9) prior (prior mean 0.1).
        probs_y = pyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2),
        )
    # Vectorized plate over tones at dim=-1. It is built once, outside the loops,
    # and re-entered at every time step below. This is the only plate that
    # occupies a tensor batch dimension in this model.
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Sequential plate: iterating it yields (possibly subsampled) sequence
    # indices one at a time and allocates no tensor dimension.
    for i in pyro.plate("sequences", num_sequences, batch_size):
        length = lengths[i]
        sequence = sequences[i, :length]
        x = 0  # the chain starts from hidden state 0 (row 0 of probs_x at t=0)
        # pyro.markov marks the time loop as Markov, so the enumerated x sites
        # can recycle enumeration dimensions across time steps.
        for t in pyro.markov(range(length)):
            x = pyro.sample(
                "x_{}_{}".format(i, t),
                dist.Categorical(probs_x[x]),
                infer={"enumerate": "parallel"},
            )
            with tones_plate:
                # NOTE(review): squeeze(-1) presumably drops the size-1 dim that
                # enumeration leaves to the right of x's enumeration dim, so that
                # probs_y's tone axis lands on dim=-1 (the tones plate) — confirm.
                pyro.sample(
                    "y_{}_{}".format(i, t),
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequence[t],
                )
My question: why do we need to set `max_plate_nesting=1` for `model_0`, as in this line?
# In model_0 the "sequences" plate is a sequential (for-loop) plate, so only the
# vectorized "tones" plate takes a batch dimension; every other model gets 2.
elbo = TraceTMC_ELBO(max_plate_nesting=2 if model is not model_0 else 1)
Isn’t `tones_plate` nested inside the `sequences` plate, which would make two nested plates? I tried setting `max_plate_nesting=2` for `model_0`, and the program still runs fine. Thank you so much for your kind help!