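Thanks. I can confirm that this (simplified) model doesn't work:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import transforms

def my_model(L, pi, theta_mean, theta_std):
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        # log-normal built as exp(Normal) via TransformedDistribution: this version fails
        theta_5 = numpyro.sample("theta_5",
            dist.TransformedDistribution(
                dist.Normal(loc=jnp.array(theta_mean[c, 2]), scale=jnp.array(theta_std[c, 4])),
                transforms.ExpTransform()
            ))
```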
but doing the `jnp.exp` change works:

```python
def crop_inference_model(L, pi, theta_mean, theta_std):
    with numpyro.plate("L", L):
        c = numpyro.sample("c", dist.Categorical(jnp.array(pi)), infer={'enumerate': 'parallel'})
        # sample on the log scale, then exponentiate deterministically
        theta_5_raw = numpyro.sample(
            "theta_5_raw",
            dist.Normal(loc=jnp.array(theta_mean[c, 2]), scale=jnp.array(theta_std[c, 4])),
        )
        theta_5 = numpyro.deterministic("theta_5", jnp.exp(theta_5_raw))
```
Stan does not support enumeration out of the box, I think. I wrote code to marginalize `c` (using log-sum-exp).
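Marginalizing `c` by hand amounts to computing, per site, `log p(theta) = logsumexp_k [log pi_k + log LogNormal(theta | mu_k, sigma_k)]`. A minimal JAX sketch of that term, assuming one mean and one scale per component (hypothetical `mu`, `sigma` arrays of shape `(K,)`, e.g. one column of `theta_mean`/`theta_std`); this illustrates the technique and is not the Stan code mentioned above:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def marginal_logpdf(theta, pi, mu, sigma):
    """log p(theta) with c ~ Categorical(pi) summed out, where
    theta = exp(raw), raw ~ Normal(mu[c], sigma[c]).
    theta: shape (L,); pi, mu, sigma: shape (K,). Returns shape (L,)."""
    log_theta = jnp.log(theta)[:, None]  # (L, 1), broadcasts against K components
    # per-component log-normal density: Normal log-pdf of log(theta) minus the log(theta) Jacobian
    comp = norm.logpdf(log_theta, loc=mu, scale=sigma) - log_theta  # (L, K)
    # log sum_k pi_k * p_k(theta), evaluated stably in log space
    return logsumexp(jnp.log(pi) + comp, axis=-1)

pi = jnp.array([0.3, 0.7])
mu = jnp.array([0.0, 1.0])
sigma = jnp.array([0.5, 1.5])
theta = jnp.array([0.5, 2.0, 4.0])
print(marginal_logpdf(theta, pi, mu, sigma))
```

Working in log space with `logsumexp` avoids underflow when the individual component densities are very small.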