Thank you for the prompt response!
Yes, it does seem to be the same.
There is only one difference, which is the initialization: [1] initializes from the emission prior of the first word in the sequence across all hidden states, whereas [2] starts from jnp.zeros, but that shouldn't have such a big impact over a 2000-step sequence.
My code for [1] is exactly as per the tutorial with:
args = parser.parse_args("-n 2000 --num-words 10 --num-categories 3 --num-supervised 0 --num-unsupervised 2000".split(" "))
My code for [2] is:
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def model_1_alt(sequences):
    num_sequences, max_length = sequences.shape
    num_categories = args.num_categories
    num_words = args.num_words
    emission_prior = jnp.repeat(0.1, num_words)
    transition_prior = jnp.ones(num_categories)
    probs_x = numpyro.sample(
        "probs_x",
        dist.Dirichlet(
            jnp.broadcast_to(transition_prior, (num_categories, num_categories))
        ).to_event(1),
    )
    probs_y = numpyro.sample(
        "probs_y",
        dist.Dirichlet(
            jnp.broadcast_to(emission_prior, (num_categories, num_words))
        ).to_event(1),
    )

    def transition_fn(carry, y):
        x_prev, t = carry  # t isn't needed anymore, as no masking is required
        with numpyro.plate("sequences", num_sequences, dim=-2):
            x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
            numpyro.sample("y", dist.Categorical(probs_y[x]), obs=y)
        return (x, t + 1), None

    # this initialization is hard to mimic in model_1 in [1]
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
```
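For what it's worth, something like the [1]-style initialization could be approximated outside the model, drawing the initial state from the emission weights of each sequence's first word instead of using jnp.zeros. This is only a hypothetical sketch, not code from either tutorial: `init_from_first_word`, the shapes, and the use of `jax.random.categorical` are my assumptions.

```python
import jax
import jax.numpy as jnp


def init_from_first_word(rng_key, sequences, probs_y):
    """Draw x_init from the emission weights of each sequence's first word.

    sequences: (num_sequences, max_length) integer word ids
    probs_y:   (num_categories, num_words) emission probabilities
    returns:   (num_sequences, 1) int32 initial hidden states
    """
    first_words = sequences[:, 0]                     # (num_sequences,)
    # Per-sequence unnormalized log-weights over hidden categories.
    logits = jnp.log(probs_y[:, first_words].T)       # (num_sequences, num_categories)
    x_init = jax.random.categorical(rng_key, logits)  # (num_sequences,)
    return x_init[:, None].astype(jnp.int32)          # (num_sequences, 1)
```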
Data simulation is done with random.PRNGKey(1), and MCMC with random.PRNGKey(2).
The sampling speed of [2] is c. 10-20x slower; you can see from the progress bar how it struggles to explore the posterior, with the step count constantly jumping up and down.
References [1] and [2] refer to the tutorial links in the first post.
Some background: run on CPU (AWS EC2 m5.xlarge), NumPyro installed from master (0.6.0), JAX 0.2.12.
EDIT:
[1] … is much better at the recovery of the true process
Please ignore this statement. When I ran it before, I was getting 3 divergences.
With the code that I posted above, I ultimately arrive at the same results, only slower (130 sec for [1] vs. 2,700 sec for [2]).