Model works with MCMC, but produces an error with log_density

Hi all,

I have a model that uses a mixture distribution inside a plate, and it causes an error with log_density even though the model samples fine with NUTS. The full model is quite big, but a minimal example with the part of the code that gives the error looks like this:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist

def signal_model(data1, data2, data_len, ...):
    # mean_bg_rate is defined elsewhere in the full model
    signal_rate = numpyro.sample('signal_rate', ndist.Uniform(0.0, 2000.))
    bg_rate_scaled = numpyro.sample('bg_rate_scaled', ndist.Normal(1, 0.1))
    bg_rate = bg_rate_scaled*mean_bg_rate
    observed_signal = numpyro.sample('observed_signal', ndist.Poisson(signal_rate + bg_rate), obs=data_len)
    # mixing weights proportional to the background and signal rates
    categorical_dist_for_E = ndist.Categorical(probs=jnp.array([bg_rate, signal_rate])/(bg_rate + signal_rate))
    E_component_distributions = [
        ndist.Uniform(0.0, 2000.),
        ndist.TruncatedNormal(1000.0, 0.5, low=0., high=2000.)
    ]
    with numpyro.plate('data', data_len) as ind:
        E = numpyro.sample('E', ndist.MixtureGeneral(categorical_dist_for_E, E_component_distributions))
        obs1 = numpyro.sample('obs1', ndist.Normal(E, jnp.sqrt(E)))
        obs2 = numpyro.sample('obs2', ndist.Normal(E, jnp.sqrt(E)))
```

When I try to run:

```python
from functools import partial
from numpyro import infer

joint_fn = partial(infer.util.log_density, signal_model, (data1, data2, len(data1)), {})
```

I get:

```
ValueError: The number of elements in 'component_distributions' must match the mixture size; expected 900, got 2
```

It seems to be a broadcasting issue caused by the plate, but why does it not affect MCMC, and how can I resolve it? I tried moving the definitions of the component distributions into the plate, but that made no difference (which I didn't really expect it to anyway). Thanks in advance for any help!
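One way to see why only log_density fails is to replay the shape arithmetic of the mixing weights (the categorical_dist_for_E expression in the model) with plain jax.numpy. NumPyro's Categorical treats the last axis of probs as the categories, and the mixture size reported in the error appears to come from that axis. The sketch below does not call NumPyro; the rate values and the count of 900 draws are hypothetical, the latter chosen only to match the error message.

```python
import jax.numpy as jnp

def mixing_probs(bg_rate, signal_rate):
    # Same expression as categorical_dist_for_E in the model
    return jnp.array([bg_rate, signal_rate]) / (bg_rate + signal_rate)

# Under NUTS each potential evaluation sees a single draw: scalar rates
scalar_probs = mixing_probs(jnp.asarray(5.0), jnp.asarray(15.0))
print(scalar_probs.shape)  # (2,): last axis = the 2 mixture components

# A dict from get_samples() gives every site a leading sample axis
num_samples = 900  # hypothetical number of posterior draws
bg = jnp.full((num_samples,), 5.0)
sig = jnp.full((num_samples,), 15.0)
batched_probs = mixing_probs(bg, sig)
print(batched_probs.shape)  # (2, 900): last axis is now the sample axis
```

With batched inputs the stacked probs end up with the sample axis last, so the Categorical looks like a 900-way mixture while only two component distributions are supplied.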

I figured it out; it was a complete brain fart. mcmc_sampler.get_samples() returns all posterior samples (stacked along a leading axis) instead of a single sample, and that's why the dimensions did not agree. I just needed to select a single sample from the samples dictionary before calling log_density:

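```python
sample_dict = mcmc_sampler.get_samples()
# keep only the last draw of every site so each value has its per-draw shape
last_sample = {k: v[-1] for k, v in sample_dict.items()}
log_joint, _ = joint_fn(last_sample)
```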