Using SVI for enumerated HMM

Great catch. I’ve gotten block to work with the model from your post, using the following code:

# Minimal example: run SVI with an AutoDelta guide on a toy model in which
# site 'a' is sampled inside block() + seed() handlers.
rng_key = random.PRNGKey(0)
def model():
    """Toy model: sample 'a' inside block()/seed() and 'b' outside them."""
    # block() and seed() are entered together and wrap only the 'a' site;
    # seed() is given the module-level rng_key.
    # NOTE(review): presumably block() hides 'a' from the guide/ELBO -- confirm
    # against the numpyro.handlers.block documentation.
    with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
        a = numpyro.sample('a', dist.Normal(0,1))
    # 'b' is sampled outside the block()/seed() context.
    b = numpyro.sample('b', dist.Dirichlet(concentration=jnp.array([2., 3, 4, 5, 6])))
optim = numpyro.optim.Adam(step_size=1e-3)
elbo = Trace_ELBO()
# The guide is built automatically from the model itself.
guide = AutoDelta(model)
svi = SVI(model, guide, optim, elbo)
# NOTE(review): the same rng_key is passed here and to seed() inside model().
# The second argument (1000) is presumably the number of optimisation steps --
# confirm against the SVI.run signature.
svi_result = svi.run(rng_key, 1000)

However, using the same with block(), seed(key): statement in my HMM model above does not work. The error trace makes it look like a JAX boolean check fails while svi.step() is trying to compute the ELBO. Is it possible that this block structure is incorrect when the sites I’m blocking are inside transition_fn(), which gets passed along to scan()?

Here’s the updated model:

def model_1(sequences, lengths, hidden_dim, include_prior=True):
    """HMM-style model whose per-step sites are sampled inside a scan() body.

    Two global sites ("probs_x", "probs_y") are sampled first. scan() then
    runs transition_fn over the time axis of `sequences`; each step samples
    a discrete site "x" (marked for parallel enumeration) and conditions
    site "y" on the observed slice for that step.

    Args:
        sequences: 3-D array whose shape is unpacked as
            (num_sequences, max_length, data_dim).
        lengths: compared element-wise against the step counter (t < lengths)
            to build the per-step mask. Presumably one length per sequence --
            confirm against the caller.
        hidden_dim: size used for the Dirichlet concentration matrix
            (hidden_dim x hidden_dim) and for the first dimension of the
            expanded Beta.
        include_prior: passed as the mask for the two global sites.

    Returns:
        None. There is no return statement and the result of scan() is
        discarded.

    Note:
        rng_key is neither a parameter nor a local; it is read from the
        enclosing scope by the seed() handler inside transition_fn.
    """
    # max_length is unpacked but not used anywhere else in this function.
    num_sequences, max_length, data_dim = sequences.shape
    # Global sites, sampled under mask(include_prior).
    with mask(mask=include_prior):
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        )
        probs_y = numpyro.sample(
            "probs_y",
            dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
        )

    def transition_fn(carry, y):
        """Single scan step: sample "x" given the previous state and observe y.

        carry is the pair (x_prev, t). Returns ((x, t + 1), None), i.e. the
        new carry and no per-step output.
        """
        x_prev, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            # Mask built from the step counter; [..., None] appends a trailing axis.
            with mask(mask=(t < lengths)[..., None]):
                # block() + seed() wrap both the enumerated "x" site and the
                # observed "y" site. seed() receives the same outer rng_key on
                # every step.
                # NOTE(review): this is the construct the surrounding post
                # reports as failing under SVI -- see the error trace below.
                with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
                    x = numpyro.sample(
                        "x",
                        dist.Categorical(probs_x[x_prev]),
                        infer={"enumerate": "parallel"},
                    )
                    with numpyro.plate("tones", data_dim, dim=-1):
                        # x.squeeze(-1) drops the last axis of x before indexing probs_y.
                        numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        return (x, t + 1), None

    # Initial state: all-zero int32 array of shape (num_sequences, 1).
    x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    # The initial carry is (x_init, 0): the initial state and a step counter of 0.
    scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))

And here’s an abbreviated version of the new error trace:

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/svi.py in loss_fn(params)
     67                 elbo.loss(
---> 68                     rng_key, params, model, guide, *args, **kwargs, **static_kwargs
     69                 ),

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
    707         # the ELBO is a lower bound that needs to be maximized.
--> 708         if self.num_particles == 1:
    709             return -single_particle_elbo(rng_key)

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/jax/core.py in __bool__(self)
    599   def __nonzero__(self): return self.aval._nonzero(self)
--> 600   def __bool__(self): return self.aval._bool(self)
    601   def __int__(self): return self.aval._int(self)

/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/jax/core.py in error(self, arg)
   1113   def error(self, arg):
-> 1114     raise ConcretizationTypeError(arg, fname_context)
   1115   return error

UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[2])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function body_fn at /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/svi.py:334 for jit, this value became a tracer due to JAX operations on these lines:

  operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
    from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/distributions/transforms.py:938 (__call__)

  operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
    from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py:708 (loss)

  operation a:bool[2] = eq b c
    from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py:708 (loss)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError