Hi all,
I have a model that uses a mixture distribution inside a plate. It samples fine with NUTS, but calling log_density on it raises an error. The full model is quite big, so here is a reduced example containing the part of the code that triggers the error:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist

def signal_model(data1, data2, data_len, ...):
    # Priors on the signal rate and on the scaled background rate.
    signal_rate = numpyro.sample('signal_rate', ndist.Uniform(0.0, 2000.))
    bg_rate_scaled = numpyro.sample('bg_rate_scaled', ndist.Normal(1, 0.1))
    # mean_bg_rate is defined elsewhere in the full model.
    bg_rate = bg_rate_scaled*mean_bg_rate
    # The number of observed events is Poisson in the total rate.
    observed_signal = numpyro.sample('observed_signal', ndist.Poisson(signal_rate + bg_rate), obs=data_len)
    # Mixing weights: background first, signal second.
    categorical_dist_for_E = ndist.Categorical(probs=jnp.array([bg_rate, signal_rate])/(bg_rate + signal_rate))
    E_component_distributions = [
        ndist.Uniform(0.0, 2000.),
        ndist.TruncatedNormal(1000.0, 0.5, low=0., high=2000.)
    ]
    with numpyro.plate('data', data_len) as ind:
        E = numpyro.sample('E', ndist.MixtureGeneral(categorical_dist_for_E, E_component_distributions))
        obs1 = numpyro.sample('obs1', ndist.Normal(E, jnp.sqrt(E)))
        obs2 = numpyro.sample('obs2', ndist.Normal(E, jnp.sqrt(E)))
When I try to run:
joint_fn = partial(infer.util.log_density, signal_model, (data1, data2, len(data1)), {})
joint_fn(mcmc_sampler.get_samples())
I get:
ValueError: The number of elements in 'component_distributions' must match the mixture size; expected 900, got 2
It seems to be a broadcasting issue caused by the plate, but why does it not affect MCMC, and how can I resolve it? I tried moving the definitions of the component distributions into the plate, but that makes no difference, and I did not really expect it to. Thanks in advance for any help!
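A note on shapes that may be relevant, offered as a sketch rather than a confirmed diagnosis. It assumes that the 900 in the error message is the number of posterior draws returned by get_samples(), so that signal_rate and bg_rate arrive as arrays of shape (900,) instead of scalars. NumPy stands in for jax.numpy here, since the stacking and broadcasting rules are the same for this case, and the rate values are made up for illustration.

```python
import numpy as np  # stand-in for jax.numpy; shapes behave the same here

num_draws = 900  # assumption: number of posterior draws in get_samples()

# During a single NUTS step the two rates are scalars (illustrative values).
bg_rate, signal_rate = 50.0, 150.0
probs_single = np.array([bg_rate, signal_rate]) / (bg_rate + signal_rate)

# With the whole dict of draws, each rate has shape (num_draws,).
bg_draws = np.full(num_draws, 50.0)
signal_draws = np.full(num_draws, 150.0)
probs_batched = np.array([bg_draws, signal_draws]) / (bg_draws + signal_draws)

# Stacking on the last axis keeps the two components as the trailing dimension.
total = (bg_draws + signal_draws)[..., None]
probs_last_axis = np.stack([bg_draws, signal_draws], axis=-1) / total

print(probs_single.shape)     # (2,)
print(probs_batched.shape)    # (2, 900)
print(probs_last_axis.shape)  # (900, 2)
```

If the mixture size is read from the last axis of probs, the batched case would report 900 components rather than 2, which matches the error text. Under that assumption, evaluating log_density one draw at a time (for example by mapping over the draws with jax.vmap) may be a more direct workaround than restructuring the model, since stacking on the last axis alone does not guarantee that the rest of the model broadcasts correctly over a batch of draws.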