Thank you for the prompt response!
Yes, it does seem to be the same.
There is only one difference, which is the initialization: whereas [1] puts in the emission prior of the first word in the sequence across all hidden states, [2] starts from jnp.zeros. That shouldn't have such a big impact on a 2000-step sequence, though.
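To make the difference concrete, here is a minimal sketch of the two initializations. The names `emission_log_prob` and `words` are hypothetical stand-ins, and the [1]-style line is paraphrased from the description above rather than copied from the tutorial:

```python
import jax.numpy as jnp
from jax import random

num_categories, num_words = 3, 10
num_sequences, seq_len = 1, 2000  # one sequence of 2000 steps

key_probs, key_words = random.split(random.PRNGKey(0))

# Hypothetical stand-ins for the emission table and the observed words.
emission_log_prob = jnp.log(
    random.dirichlet(key_probs, jnp.ones(num_words), (num_categories,))
)
words = random.randint(key_words, (seq_len,), 0, num_words)

# [1]-style: the forward pass starts from the emission (log-)probability
# of the first observed word under each hidden state.
init_log_prob = emission_log_prob[:, words[0]]  # shape (num_categories,)

# [2]-style (the scan model below): each sequence starts pinned to state 0.
x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
```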
My code for [1] is exactly as per the tutorial, with:

```python
args = parser.parse_args(
    "-n 2000 --num-words 10 --num-categories 3 --num-supervised 0 --num-unsupervised 2000".split(" ")
)
```
My code for [2] is:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def model(sequences, num_categories, num_words, emission_prior):
    num_sequences, max_length = sequences.shape
    transition_prior = jnp.ones(num_categories)
    probs_x = numpyro.sample(
        "probs_x",
        dist.Dirichlet(
            jnp.broadcast_to(transition_prior, (num_categories, num_categories))
        ).to_event(1),
    )
    probs_y = numpyro.sample(
        "probs_y",
        dist.Dirichlet(
            jnp.broadcast_to(emission_prior, (num_categories, num_words))
        ).to_event(1),
    )

    def transition_fn(carry, y):
        x_prev, t = carry  # t isn't needed anymore as no masking is required
        with numpyro.plate("sequences", num_sequences, dim=-2):
            x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
            y = numpyro.sample("y", dist.Categorical(probs_y[x]), obs=y)
        return (x, t + 1), None

    # this initialization is hard to mimic in model_1 in [1]
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # scan iterates over the leading axis, so the data is made time-major first
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
```
Simulation of the data is done with random.PRNGKey(1), and MCMC is run with random.PRNGKey(2).
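For completeness, a sketch of how the model above is run; `num_warmup`/`num_samples` and the flat `emission_prior` are placeholder assumptions rather than my exact settings, and `sequences` comes from the simulation step:

```python
import jax.numpy as jnp
from jax import random
from numpyro.infer import MCMC, NUTS

# NUTS enumerates out the discrete states x inside scan.
# num_warmup/num_samples are placeholders, not my exact settings.
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
mcmc.run(
    random.PRNGKey(2),            # the MCMC key mentioned above
    sequences,                    # simulated with random.PRNGKey(1)
    num_categories=3,
    num_words=10,
    emission_prior=jnp.ones(10),  # assumed flat prior, for illustration only
)
mcmc.print_summary()
```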
The sampling speed of [2] is c. 10-20x slower: you can see from the progress bar how it struggles to explore the posterior, with the step count constantly jumping up and down.
References [1] and [2] refer to the tutorial links in the first post.
Some background: this ran on CPU, on an AWS EC2 m5.xlarge.
NumPyro installed from master (0.6.0), JAX 0.2.12.
> … is much better at the recovery of the true process
Please ignore this statement. When I ran it before, I was getting 3x divergences.
With the code that I posted above, I ultimately arrive at the same results, only slower (130 sec for [1] vs. 2,700 sec for [2]).