Intuition for the difference between the two HMM tutorials (forward algorithm vs. marginalization)

Hi everyone,

I’ve been studying the following two HMM examples in the NumPyro docs (thank you for them, btw!):

  1. HMM - leveraging the forward algorithm
  2. HMM - marginalizing out the latent discrete vars

To my naive eye, the two implementations look fairly similar (both scan along the sequence to compute the log-likelihood).
I ran a small benchmark on the same generative model (categorical observations, as with num_categories/num_words in [1]; a 2000-step sequence of latent states; a 3x10 emission probability matrix) and was surprised by the large difference in performance.

[1] seems to run an order of magnitude faster and is much better at the recovery of the true process.

I’m probably missing something obvious. Do you have some intuition for the big difference?

Here is a comparison based on what I know.

  • Algorithm: both use the forward algorithm; the underlying math is intended to be the same
  • Compile time: [1] compiles faster
  • Running time on CPU: both should be about the same; [1] is probably a little faster
  • Running time on GPU: [2] will be faster, especially when the time dimension is large ([2] uses a parallel-scan algorithm)
  • Recovering the latent states: for [1] you need some extra math to implement the backward algorithm, whereas inferring the latent states will be supported for [2] in an upcoming release
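To make the "same math" point in the first bullet and the parallel-scan point in the GPU bullet concrete, here is a minimal NumPy sketch for a single sequence. It is my own illustration, not code from either tutorial; the function names and the uniform initial-state distribution are assumptions. `forward_sequential` is the left-to-right log-space recursion that a scan like [1] performs; `forward_tree` folds the per-step "transition times emission" matrices together pairwise in the log semiring, which is the associativity a parallel scan exploits to reach O(log T) depth. Both are checked against brute-force enumeration of every hidden path on a tiny example.

```python
import itertools
import numpy as np


def logsumexp(a, axis=None):
    # numerically stable log(sum(exp(a))) along `axis`
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis) if axis is not None else out.item()


def log_matmul(A, B):
    # matrix product in the log semiring: out[i, k] = logsumexp_j(A[i, j] + B[j, k])
    return logsumexp(A[:, :, None] + B[None, :, :], axis=1)


def forward_sequential(log_init, log_trans, log_emit, ys):
    # alpha[i] = log p(y_0..y_t, x_t = i); one update per observation, like a scan
    alpha = log_init + log_emit[:, ys[0]]
    for y in ys[1:]:
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_emit[:, y]
    return logsumexp(alpha)


def forward_tree(log_init, log_trans, log_emit, ys):
    # step matrix M_t[i, j] = log p(x_t = j | x_{t-1} = i) + log p(y_t | x_t = j)
    mats = [log_trans + log_emit[None, :, y] for y in ys[1:]]
    # combine neighbours pairwise, keeping their order; each level is parallelizable
    while len(mats) > 1:
        paired = [log_matmul(mats[i], mats[i + 1]) for i in range(0, len(mats) - 1, 2)]
        if len(mats) % 2:
            paired.append(mats[-1])
        mats = paired
    alpha0 = log_init + log_emit[:, ys[0]]
    if mats:
        alpha0 = logsumexp(alpha0[:, None] + mats[0], axis=0)
    return logsumexp(alpha0)


def brute_force(log_init, log_trans, log_emit, ys):
    # sum over all K**T hidden paths; only feasible for tiny T
    K = log_init.shape[0]
    terms = []
    for path in itertools.product(range(K), repeat=len(ys)):
        lp = log_init[path[0]] + log_emit[path[0], ys[0]]
        for t in range(1, len(ys)):
            lp += log_trans[path[t - 1], path[t]] + log_emit[path[t], ys[t]]
        terms.append(lp)
    return logsumexp(np.array(terms))


rng = np.random.default_rng(0)
K, W, T = 3, 10, 7  # 3 hidden states, 10 words, a short sequence
log_init = np.log(np.full(K, 1.0 / K))
log_trans = np.log(rng.dirichlet(np.ones(K), size=K))
log_emit = np.log(rng.dirichlet(np.full(W, 0.1), size=K))
ys = rng.integers(0, W, size=T)

print(forward_sequential(log_init, log_trans, log_emit, ys))
print(forward_tree(log_init, log_trans, log_emit, ys))
print(brute_force(log_init, log_trans, log_emit, ys))
```

A real parallel scan returns every prefix rather than only the total and runs each level of the reduction concurrently on the accelerator; this sketch only shows that regrouping the products leaves the answer unchanged up to floating-point rounding.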

[1] … is much better at the recovery of the true process

The underlying math is intended to be the same. Probably there is a bug somewhere. Could you share your code so I can take a look?

Thank you for the prompt response!

Yes, it does seem to be the same.
There is only one difference: the initialization. [1] seeds the forward recursion with the emission log-probabilities of the first word in the sequence across all hidden states, whereas [2] starts from jnp.zeros (i.e., a fixed previous state 0), but that shouldn’t have a big impact on a 2000-step sequence.
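The intuition that the initialization barely matters can be checked directly. In the NumPy sketch below (random parameters and random words, not the benchmark data; the helper names are hypothetical), (a) seeds the forward recursion [1]-style with the first word's emission log-probabilities for every state, while (b) mimics [2]'s x_init = 0, so the first hidden state is drawn from probs_x[0] before the first word is emitted. Because (b) only reweights the same nonnegative per-state terms by probs_x[0], the log-likelihood gap should stay between -log max(probs_x[0]) and -log min(probs_x[0]) however long the sequence is.

```python
import numpy as np


def logsumexp(a, axis=None):
    # numerically stable log(sum(exp(a)))
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis) if axis is not None else out.item()


def loglik(alpha, log_trans, log_emit, ys):
    # log-space forward recursion from a given initial alpha over y_1..y_{T-1}
    for y in ys[1:]:
        alpha = logsumexp(alpha[:, None] + log_trans, axis=0) + log_emit[:, y]
    return logsumexp(alpha)


rng = np.random.default_rng(1)
K, W, T = 3, 10, 2000  # same sizes as the benchmark described above
log_trans = np.log(rng.dirichlet(np.ones(K), size=K))        # rows play the role of probs_x
log_emit = np.log(rng.dirichlet(np.full(W, 0.1), size=K))    # rows play the role of probs_y
ys = rng.integers(0, W, size=T)  # random words, just to get a long sequence

# (a) [1]-style: every state seeded with the first word's emission log-prob
ll_a = loglik(log_emit[:, ys[0]], log_trans, log_emit, ys)
# (b) [2]-style: previous state fixed to 0, so x_0 ~ probs_x[0] before emitting y_0
ll_b = loglik(log_trans[0] + log_emit[:, ys[0]], log_trans, log_emit, ys)

gap = ll_a - ll_b
print(ll_a, ll_b, gap)
# bounds on the gap that do not depend on T
print(-log_trans[0].max(), -log_trans[0].min())
```

The gap is a bounded shift of the log-density that does not grow with the sequence length, consistent with the expectation that the initialization shouldn't matter much for 2000 steps.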

My code for [1] is exactly the tutorial code, run with:

    args = parser.parse_args("-n 2000 --num-words 10 --num-categories 3 --num-supervised 0 --num-unsupervised 2000".split(" "))

My code for [2] is:

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.control_flow import scan

    def model_1_alt(sequences):
        num_sequences, max_length = sequences.shape
        num_categories = args.num_categories  # `args` from the parse_args call above
        num_words = args.num_words

        emission_prior = jnp.repeat(0.1, num_words)
        transition_prior = jnp.ones(num_categories)

        # one Dirichlet-distributed row per hidden state
        probs_x = numpyro.sample(
            "probs_x",
            dist.Dirichlet(
                jnp.broadcast_to(transition_prior, (num_categories, num_categories))
            ).to_event(1),
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Dirichlet(
                jnp.broadcast_to(emission_prior, (num_categories, num_words))
            ).to_event(1),
        )

        def transition_fn(carry, y):
            x_prev, t = carry  # t isn't needed anymore, as no masking is required
            with numpyro.plate("sequences", num_sequences, dim=-2):
                # discrete latent state; marginalized out by enumeration during inference
                x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
                # y[..., None] lines the observation up with the plate at dim=-2
                numpyro.sample("y", dist.Categorical(probs_y[x]), obs=y[..., None])
            return (x, t + 1), None

        # this initialization is hard to mimic in model_1 in [1]
        x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
        # swapaxes moves the time dimension to the front so that scan iterates over it
        scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))

The data is simulated with random.PRNGKey(1), and MCMC is run with random.PRNGKey(2).

Sampling with [2] is roughly 10-20x slower; you can see from the progress bar how it struggles to explore the posterior, with the number of steps per iteration constantly jumping up and down.

References [1] and [2] refer to the tutorial links in the first post.
Some background: run on CPU (AWS EC2 m5.xlarge), with NumPyro installed from master (0.6.0) and JAX 0.2.12.

EDIT:

[1] … is much better at the recovery of the true process

Please ignore this statement. When I ran it before, I was getting 3x divergences.
With the code that I posted above, I ultimately arrive at the same results, only slower (130 sec for [1] vs. 2,700 sec for [2]).


Please also keep in mind that two pieces of code that map onto equivalent math will not, in general, be equivalent when run on a computer with finite precision. I’d generally expect the parallel-scan version to have better numerical properties in terms of underflow/overflow etc., so it might also perform better in practice (although any differences would probably go away if you used, e.g., 256-bit precision).
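A toy illustration of the finite-precision point, using the sequence length from this thread: the probability of a 2000-step path is a product of 2000 factors, which underflows to exactly 0.0 in float64 long before the end, while the same quantity accumulated in log space stays an ordinary finite number. Real implementations work in log space, and [2]'s parallel scan additionally changes the order of operations, so this plain-Python sketch (with a made-up per-step probability of 0.1) only shows why equivalent math can behave differently at finite precision; it is not a diagnosis of the slowdown reported above.

```python
import math

T = 2000
p_step = 0.1  # made-up per-step probability (e.g. a single emission term)

# naive probability-space product: underflows in float64
prob = 1.0
for _ in range(T):
    prob *= p_step

# log-space accumulation of the same quantity: stays finite
log_prob = 0.0
for _ in range(T):
    log_prob += math.log(p_step)

print(prob)  # 0.0
print(log_prob)
```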


130 sec for [1] vs. 2,700 sec for [2]

I am not sure what sequences is in your code. I guess it is unsupervised_words.reshape((1, -1)), i.e. num_sequences=1? If so, this seems like an important performance issue. Could you make a github thread for this?

If num_sequences=1, then you can simplify the code as follows

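    # inside model_1_alt, after sampling probs_x and probs_y
    def transition_fn(carry, y):
        x_prev = carry  # the carry is just the previous hidden state
        x = numpyro.sample("x", dist.Categorical(probs_x[x_prev]))
        numpyro.sample("y", dist.Categorical(probs_y[x]), obs=y)
        return x, None

    x_init = 0
    # the carry is x_prev alone, so pass x_init (not a tuple) as the initial carry
    scan(transition_fn, x_init, unsupervised_words)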

I am not sure what sequences is in your code. I guess it is unsupervised_words.reshape((1, -1)), i.e. num_sequences=1?

Apologies! You’re correct. That’s why I included the argparse statement: I use the same simulated data as in example [1]; for [2] I just needed to reshape it.

If num_sequences=1, then you can simplify the code as follows

Thank you. I kept the original version to stay as close to the tutorial as possible and rule out errors on my side, but I’ll use this version for the GitHub issue.

Could you make a github thread for this?

I’ll do that tonight!

Thank you for your help!