Using SVI for enumerated HMM

I am trying to recreate the NumPyro enumerated-HMM example with SVI instead of MCMC, but I am getting the following error.

ValueError: Continuous inference cannot handle discrete sample site 'x'.

Here is my model (taken from the online example)

#     x[t-1] --> x[t] --> x[t+1]
#        |        |         |
#        V        V         V
#     y[t-1]     y[t]     y[t+1]
#
# This model includes a plate for the data_dim = 44 keys on the piano. This
# model has two "style" parameters probs_x and probs_y that we'll draw from a
# prior. The latent state is x, and the observed state is y.
def model_1(sequences, lengths, hidden_dim, include_prior=True):
    """HMM with a discrete latent state ``x`` and Bernoulli observations ``y``.

    Samples a transition matrix ``probs_x`` and an emission matrix ``probs_y``
    from their priors, then scans ``transition_fn`` over the time axis of
    ``sequences``.

    Args:
        sequences: 3-D array; its shape is unpacked as
            ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths compared against the time index ``t``
            (presumably shape ``(num_sequences,)`` -- TODO confirm).
        hidden_dim: number of values the latent state ``x`` can take.
        include_prior: passed as the mask around the two prior sites.

    Returns:
        None; the result of ``scan`` is discarded.
    """
    # max_length is unpacked here but not used below.
    num_sequences, max_length, data_dim = sequences.shape
    # Both prior sites are wrapped in mask(include_prior); with
    # include_prior=False the mask handler receives a False mask for them.
    with mask(mask=include_prior):
        # hidden_dim x hidden_dim matrix; the Dirichlet concentration is 1.0 on
        # the diagonal and 0.1 elsewhere (0.9 * eye + 0.1).
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        # Beta(0.1, 0.9) expanded to shape [hidden_dim, data_dim].
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        """One scan step: sample ``x`` given ``x_prev``, then observe ``y``."""
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            # Only steps with t < lengths are kept by the mask; [..., None]
            # appends a trailing axis to the comparison result.
            with mask(mask=(t < lengths)[..., None]):
                # Discrete site marked for parallel enumeration. This is the
                # site named in the reported
                # "Continuous inference cannot handle discrete sample site 'x'".
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[x_prev]),
                    infer={"enumerate": "parallel"},
                )
                with numpyro.plate("tones", data_dim, dim=-1):
                    # x's last axis is dropped before it indexes rows of probs_y.
                    numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        # New carry is (current state, next time index); nothing is collected.
        return (x, t + 1), None

    # Initial latent state: int32 zeros of shape (num_sequences, 1).
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))

And code for running SVI on said model

def main(num_samples=1000, hidden_dim=16, truncate=None, num_sequences=None,
         kernel='nuts', num_warmup=500, num_chains=1, device='cpu'):
    """Load the JSB chorales training split and run SVI on ``model_1``.

    Args:
        num_samples: only referenced by the commented-out MCMC code.
        hidden_dim: forwarded to the model through ``svi.run``.
        truncate: if truthy, ``lengths`` is clipped to it and ``sequences`` is
            cut to at most this many entries along its second axis.
        num_sequences: if truthy, keep only the first ``num_sequences``
            sequences and lengths.
        kernel: only referenced by the commented-out MCMC code.
        num_warmup: only referenced by the commented-out MCMC code.
        num_chains: passed to ``numpyro.set_host_device_count``; otherwise only
            referenced by the commented-out MCMC code.
        device: passed to ``numpyro.set_platform``.

    Returns:
        None; ``svi_result`` is assigned but not returned.
    """

    model = model_1
    
    numpyro.set_platform(device)
    numpyro.set_host_device_count(num_chains)

    _, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
    lengths, sequences = fetch()
    if num_sequences:
        sequences = sequences[0 : num_sequences]
        lengths = lengths[0 : num_sequences]

    logger.info("-" * 40)
    logger.info("Training {} on {} sequences".format(model.__name__, len(sequences)))

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played (we remove 37/88 notes with default args)
    sequences = sequences[:, :, present_notes]

    if truncate:
        lengths = lengths.clip(0, truncate)
        sequences = sequences[:, : truncate]

    logger.info("Each sequence has shape {}".format(sequences[0].shape))
    logger.info("Starting inference...")
    rng_key = random.PRNGKey(2)
    # This first timestamp is overwritten below before it is ever read.
    start = time.time()
#     kernel = {"nuts": NUTS, "hmc": HMC}[kernel](model)
#     mcmc = MCMC(
#         kernel,
#         num_warmup=num_warmup,
#         num_samples=num_samples,
#         num_chains=num_chains,
#         progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
#     )
#     mcmc.run(rng_key, sequences, lengths, hidden_dim)
#     mcmc.print_summary()
#     logger.info("\nMCMC elapsed time: {}".format(time.time() - start))

    # NOTE(review): the dict argument looks like the Pyro optimizer convention;
    # the working version later in this thread uses Adam(step_size=1e-3).
    # Which Adam is imported is not visible in this snippet -- confirm.
    optim = Adam({'lr': 0.01, 'betas': [0.8, 0.99]})
    # NOTE(review): a later reply in this thread suggests the first positional
    # argument is expected to be num_particles, not a PRNG key -- confirm.
    elbo = Trace_ELBO(rng_key)

    # NOTE(review): `poutine.block` with `expose_fn` is the Pyro API; the thread
    # later replaces it with numpyro.handlers.block. Intent: expose only the
    # sites whose names start with "probs_" to the guide.
    guide = AutoDelta(
        poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))
    )

    svi = SVI(model, guide, optim, loss=elbo)
    start = time.time()
    # Runs 25 SVI steps; sequences, lengths and hidden_dim are the model args.
    svi_result = svi.run(rng_key, 25, sequences, lengths, hidden_dim)
    logger.info("\nSVI elapsed time: {}".format(time.time() - start))

The model runs (albeit slowly) when I use MCMC, and I’ve gotten enumerated HMMs to work with SVI using Pyro. Is there something wrong with my code, or is the SVI + discrete HMM combination not supported in NumPyro? My ultimate goal in exploring NumPyro was to speed up runtime by using JAX’s scan(), as its documentation makes me think it would be more efficient than writing a Python for-loop over each time point in a Pyro HMM.

Any help or advice would be greatly appreciated.

Best,
Adam

You should use TraceGraph_ELBO instead of Trace_ELBO, because the latter can’t handle discrete variables.

Thanks for the feedback @yaow. I switched from Trace_ELBO to TraceGraph_ELBO but got the same error.

Tinkering around a bit I noticed that my use of pyro.poutine.block() within AutoDelta seemed inappropriate for a numpyro model. Thus I swapped the following line

guide = AutoDelta(
        poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))
    )

for this

guide = AutoDelta(
        numpyro.handlers.block(model, hide=['x', 'y'])
    )

Making this change eliminates the discrete sites error (ValueError: Continuous inference cannot handle discrete sample site 'x'.). However, I am now getting the following error that seems to be associated with PRNG keys and message names within scan() during svi.run(). Here’s an excerpt of the error trace:

/juno/work/users/weinera2/venv2/lib/python3.7/site-packages/numpyro/distributions/discrete.py in sample(self, key, sample_shape)
    308     def sample(self, key, sample_shape=()):
--> 309         assert is_prng_key(key)
    310         return categorical(key, self.probs, shape=sample_shape + self.batch_shape)

UnfilteredStackTrace: AssertionError

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

Note that this new error now occurs regardless of whether I use TraceGraph_ELBO or Trace_ELBO.

Any feedback for this new error?

I found the cause of the problem: it seems that block hides the sample sites, so SVI can’t pass its RNG key to the hidden sample sites.
See this post for the solution.

1 Like

Great catch. I’ve gotten block to work using the model in your post using the following code

rng_key = random.PRNGKey(0)


def model():
    """Toy model: site 'a' is sampled under block() + seed(), site 'b' is not."""
    # 'a' lives inside block(), and seed() gives it its own PRNG key there.
    with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
        numpyro.sample('a', dist.Normal(0, 1))
    # 'b' is sampled outside the block/seed context.
    concentration = jnp.array([2., 3, 4, 5, 6])
    numpyro.sample('b', dist.Dirichlet(concentration=concentration))


# AutoDelta guide derived from the model itself, then ELBO, optimizer and SVI.
guide = AutoDelta(model)
elbo = Trace_ELBO()
optim = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model, guide, optim, loss=elbo)
# 1000 optimisation steps, driven by the same module-level key.
svi_result = svi.run(rng_key, 1000)

However, using the with block(), seed(key): statement for my above HMM model does not work. The error trace makes it look like there’s something wrong with JAX boolean statements when svi.step() is trying to compute the ELBO. Is it possible this block structure is incorrect when I’m blocking sites within transition_fn() that get passed along to scan()?

Here’s the updated model

def model_1(sequences, lengths, hidden_dim, include_prior=True):
    """HMM variant whose ``x`` and ``y`` sites are sampled under block() + seed().

    Args:
        sequences: 3-D array; its shape is unpacked as
            ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths compared against the time index ``t``
            (presumably shape ``(num_sequences,)`` -- TODO confirm).
        hidden_dim: number of values the latent state ``x`` can take.
        include_prior: passed as the mask around the two prior sites.

    Returns:
        None; the result of ``scan`` is discarded.
    """
    # max_length is unpacked here but not used below.
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        # hidden_dim x hidden_dim matrix; the Dirichlet concentration is 1.0 on
        # the diagonal and 0.1 elsewhere (0.9 * eye + 0.1).
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        # Beta(0.1, 0.9) expanded to shape [hidden_dim, data_dim].
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        """One scan step: sample ``x`` given ``x_prev``, then observe ``y``."""
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            # Only steps with t < lengths are kept by the mask.
            with mask(mask=(t < lengths)[..., None]):
                # rng_key is neither a parameter nor a local of model_1; it is
                # read from an enclosing scope.
                # NOTE(review): both "x" and the observed "y" are inside
                # block(); presumably that hides them from the handlers SVI
                # applies outside this function -- confirm.
                with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
                    x = numpyro.sample(
                        "x",
                        dist.Categorical(probs_x[x_prev]),
                        infer={"enumerate": "parallel"},
                    )
                    with numpyro.plate("tones", data_dim, dim=-1):
                        # x's last axis is dropped before it indexes rows of probs_y.
                        numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        # New carry is (current state, next time index); nothing is collected.
        return (x, t + 1), None

    # Initial latent state: int32 zeros of shape (num_sequences, 1).
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))

And here’s an abbreviation of the new error trace

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/svi.py in loss_fn(params)
     67                 elbo.loss(
---> 68                     rng_key, params, model, guide, *args, **kwargs, **static_kwargs
     69                 ),

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
    707         # the ELBO is a lower bound that needs to be maximized.
--> 708         if self.num_particles == 1:
    709             return -single_particle_elbo(rng_key)

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/jax/core.py in __bool__(self)
    599   def __nonzero__(self): return self.aval._nonzero(self)
--> 600   def __bool__(self): return self.aval._bool(self)
    601   def __int__(self): return self.aval._int(self)

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/jax/core.py in error(self, arg)
   1113   def error(self, arg):
-> 1114     raise ConcretizationTypeError(arg, fname_context)
   1115   return error

UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[2])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function body_fn at /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/svi.py:334 for jit, this value became a tracer due to JAX operations on these lines:

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/distributions/transforms.py:938 (__call__)

  operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
    from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py:708 (loss)

  operation a:bool[2] = eq b c
    from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py:708 (loss)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

Probably it is caused by the Trace_ELBO(rng_key) line? When an error happens, you can trace back to see why we get an issue at the statement num_particles == 1. In this case, I guess num_particles is expected to be a concrete positive integer, rather than rng_key.

Thank you @fehiepsi. Fixing how I initialize TraceGraph_ELBO() and Adam() got rid of all errors. Here’s my final model that works for anyone who runs into the same problems as me. Note that I’m using numpyro version 0.10.0

import argparse
import logging
import os
import time

from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.contrib.control_flow import scan
import numpyro.distributions as dist
from numpyro.examples.datasets import JSB_CHORALES, load_dataset
from numpyro.handlers import mask
from numpyro.infer import HMC, MCMC, NUTS, SVI, Trace_ELBO, TraceGraph_ELBO
from numpyro.ops.indexing import Vindex
from numpyro.infer.autoguide import AutoDelta, AutoGuide
from numpyro.optim import Adam

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

rng_key = random.PRNGKey(0)

#     x[t-1] --> x[t] --> x[t+1]
#        |        |         |
#        V        V         V
#     y[t-1]     y[t]     y[t+1]
#
# This model includes a plate for the data_dim = 44 keys on the piano. This
# model has two "style" parameters probs_x and probs_y that we'll draw from a
# prior. The latent state is x, and the observed state is y.
def model_1(sequences, lengths, hidden_dim, include_prior=True):
    """HMM whose ``x`` and ``y`` sites are sampled under block() + seed().

    Samples ``probs_x`` and ``probs_y`` from their priors, then scans
    ``transition_fn`` over the time axis of ``sequences``.

    Args:
        sequences: 3-D array; its shape is unpacked as
            ``(num_sequences, max_length, data_dim)``.
        lengths: per-sequence lengths compared against the time index ``t``
            (presumably shape ``(num_sequences,)`` -- TODO confirm).
        hidden_dim: number of values the latent state ``x`` can take.
        include_prior: passed as the mask around the two prior sites.

    Returns:
        None; the result of ``scan`` is discarded.
    """
    # max_length is unpacked here but not used below.
    num_sequences, max_length, data_dim = sequences.shape
    with mask(mask=include_prior):
        # hidden_dim x hidden_dim matrix; the Dirichlet concentration is 1.0 on
        # the diagonal and 0.1 elsewhere (0.9 * eye + 0.1).
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        # Beta(0.1, 0.9) expanded to shape [hidden_dim, data_dim].
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        """One scan step: sample ``x`` given ``x_prev``, then observe ``y``."""
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            # Only steps with t < lengths are kept by the mask.
            with mask(mask=(t < lengths)[..., None]):
                # Uses the module-level rng_key, i.e. the same key on every call.
                # NOTE(review): the observed site "y" is inside block() together
                # with "x"; presumably this keeps the data likelihood out of the
                # trace SVI sees, so probs_x/probs_y may be fit to the prior
                # only -- confirm before trusting the fitted values.
                with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
                    x = numpyro.sample(
                        "x",
                        dist.Categorical(probs_x[x_prev]),
                        infer={"enumerate": "parallel"},
                    )
                    with numpyro.plate("tones", data_dim, dim=-1):
                        # x's last axis is dropped before it indexes rows of probs_y.
                        numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        # New carry is (current state, next time index); nothing is collected.
        return (x, t + 1), None

    # Initial latent state: int32 zeros of shape (num_sequences, 1).
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))

def main(num_samples=1000, hidden_dim=16, truncate=None, num_sequences=None,
         kernel='nuts', num_warmup=500, num_chains=1, device='cpu',
         num_steps=25):
    """Fit ``model_1`` to the JSB chorales training split with SVI.

    Args:
        num_samples: unused here; kept so existing callers of the MCMC-era
            signature keep working.
        hidden_dim: number of latent states, forwarded to the model.
        truncate: if truthy, ``lengths`` is clipped to it and ``sequences`` is
            cut to at most this many entries along its second axis.
        num_sequences: if truthy, keep only the first ``num_sequences``
            sequences and lengths.
        kernel: unused here; kept for signature compatibility.
        num_warmup: unused here; kept for signature compatibility.
        num_chains: unused here; kept for signature compatibility.
        device: passed to ``numpyro.set_platform``.
        num_steps: number of SVI optimisation steps (previously hard-coded
            to 25).

    Returns:
        The value returned by ``SVI.run`` (in NumPyro, a result object carrying
        the optimised ``params``, the final ``state`` and the ``losses``), so
        that the fitted ``probs_x`` / ``probs_y`` can be inspected by the
        caller. The original version discarded it.
    """
    model = model_1

    numpyro.set_platform(device)

    _, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
    lengths, sequences = fetch()
    if num_sequences:
        sequences = sequences[0 : num_sequences]
        lengths = lengths[0 : num_sequences]

    logger.info("-" * 40)
    logger.info("Training {} on {} sequences".format(model.__name__, len(sequences)))

    # find all the notes that are present at least once in the training set
    present_notes = (sequences == 1).sum(0).sum(0) > 0
    # remove notes that are never played
    sequences = sequences[:, :, present_notes]

    if truncate:
        lengths = lengths.clip(0, truncate)
        sequences = sequences[:, : truncate]

    logger.info("Each sequence has shape {}".format(sequences[0].shape))
    logger.info("Starting inference...")

    optim = Adam(step_size=1e-3)
    elbo = TraceGraph_ELBO()
    # AutoDelta guide built directly from the model.
    guide = AutoDelta(model)

    svi = SVI(model, guide, optim, loss=elbo)
    start = time.time()
    # rng_key is the module-level key; the trailing arguments are model args.
    svi_result = svi.run(rng_key, num_steps, sequences, lengths, hidden_dim)
    logger.info("\nSVI elapsed time: {}".format(time.time() - start))
    return svi_result

main()
1 Like

I’m getting some weird results for probs_x (the transition matrix) that make me concerned that I’m “blocking” the model from observing the data y, and thus preventing it from truly being fit to the data. Is there an alternative way of blocking discrete sites when building the guide function that still allows the model to fit the data?