Great catch. I’ve gotten `block` to work with the model from your post, using the following code:
# Single PRNG key, reused both for seeding inside the model and for running SVI.
rng_key = random.PRNGKey(0)


def model():
    """Toy model whose two sample sites are both drawn under ``block`` and ``seed``.

    Reads ``rng_key`` from module scope to seed the sample statements.
    """
    with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
        a = numpyro.sample("a", dist.Normal(0, 1))
        concentration = jnp.array([2.0, 3.0, 4.0, 5.0, 6.0])
        b = numpyro.sample("b", dist.Dirichlet(concentration=concentration))


# Fit an AutoDelta guide with Adam for 1000 SVI steps.
optim = numpyro.optim.Adam(step_size=1e-3)
elbo = Trace_ELBO()
guide = AutoDelta(model)
svi = SVI(model, guide, optim, elbo)
svi_result = svi.run(rng_key, 1000)
However, using the same `with block(), seed(key):` statement in my HMM model above does not work. The error trace makes it look like something is going wrong with a JAX boolean check when `svi.step()` tries to compute the ELBO. Is it possible that this block structure is incorrect when I’m blocking sites within `transition_fn()` that get passed along to `scan()`?
Here’s the updated model:
def model_1(sequences, lengths, hidden_dim, include_prior=True):
    """HMM whose discrete state site ``x`` is sampled under ``block``/``seed`` inside ``scan``.

    Args:
        sequences: array whose shape unpacks as
            ``(num_sequences, max_length, data_dim)``.
        lengths: compared against the running time index to mask out steps
            past the end of each sequence.
        hidden_dim: number of hidden states; sets the size of the ``probs_x``
            and ``probs_y`` priors.
        include_prior: passed to ``mask`` around the ``probs_x`` / ``probs_y``
            sample sites.

    Note:
        ``rng_key`` is not a parameter; it is read from the enclosing scope.
    """
    num_sequences, max_length, data_dim = sequences.shape

    # Global transition / emission parameters, masked by ``include_prior``.
    with mask(mask=include_prior):
        transition_prior = dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
        probs_x = numpyro.sample("probs_x", transition_prior)
        emission_prior = dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2)
        probs_y = numpyro.sample("probs_y", emission_prior)

    def transition_fn(carry, y):
        """One scan step: sample the next state and score the observation ``y``."""
        prev_state, step = carry
        # True only for sequences that have not yet ended at this step.
        within_length = (step < lengths)[..., None]
        with numpyro.plate("sequences", num_sequences, dim=-2), mask(mask=within_length):
            with numpyro.handlers.block(), numpyro.handlers.seed(rng_seed=rng_key):
                x = numpyro.sample(
                    "x",
                    dist.Categorical(probs_x[prev_state]),
                    infer={"enumerate": "parallel"},
                )
            # NOTE(review): the pasted snippet lost its indentation; the "tones"
            # plate is assumed to sit outside the block/seed context - confirm.
            with numpyro.plate("tones", data_dim, dim=-1):
                numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
        return (x, step + 1), None

    initial_state = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # NB swapaxes: we move time dimension of `sequences` to the front to scan over it
    time_major = jnp.swapaxes(sequences, 0, 1)
    scan(transition_fn, (initial_state, 0), time_major)
And here’s an abbreviated version of the new error trace:
/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/svi.py in loss_fn(params)
67 elbo.loss(
---> 68 rng_key, params, model, guide, *args, **kwargs, **static_kwargs
69 ),
/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
707 # the ELBO is a lower bound that needs to be maximized.
--> 708 if self.num_particles == 1:
709 return -single_particle_elbo(rng_key)
/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/jax/core.py in __bool__(self)
599 def __nonzero__(self): return self.aval._nonzero(self)
--> 600 def __bool__(self): return self.aval._bool(self)
601 def __int__(self): return self.aval._int(self)
/juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/jax/core.py in error(self, arg)
1113 def error(self, arg):
-> 1114 raise ConcretizationTypeError(arg, fname_context)
1115 return error
UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[2])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function body_fn at /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/svi.py:334 for jit, this value became a tracer due to JAX operations on these lines:
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b
from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/distributions/transforms.py:938 (__call__)
operation a:u32[] = convert_element_type[new_dtype=uint32 weak_type=False] b
from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py:708 (loss)
operation a:bool[2] = eq b c
from line /juno/work/shah/users/weinera2/projects/scdna_replication_tools/venv2/lib/python3.7/site-packages/numpyro/infer/elbo.py:708 (loss)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError