Hi everyone,
I’m running into some behaviour I'm struggling to understand. I have a model with discrete (Bernoulli) latent variables. The relevant part of the model is:
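```python
# Discrete latent: Bernoulli(0.1) assignment, indexed by construct below
assignment = numpyro.sample(
    "assignment",
    dist.Bernoulli(probs=0.1),
)
# Look up each observation's assignment via its construct code
assignments_per_obs = assignment[construct_code]
# Force assignment 0 wherever fixed_assignments is True
assignments_per_obs = jnp.where(
    fixed_assignments, 0, assignments_per_obs
)
# Per observation, use the expression of the sequence (assignment 0)
# or of the alternate sequence (assignment 1)
unscaled_strain_expression_per_obs = numpyro.deterministic(
    "unscaled_strain_expression_per_obs",
    jnp.where(
        assignments_per_obs == 0,
        strain_expression[sequence_code],
        strain_expression[alternate_sequence_code],
    ),
)
```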
I then have a very simple likelihood.
I can sample from the prior predictive of this model (and get sensible results), and I can run MCMC on it as long as I comment out the likelihood. If I leave the likelihood in, I get this error:
```
ValueError: Output mismatch: Bint[10151] vs Bint[2]
```
I think this is related to some operation in the backward pass during parallel enumeration. Unfortunately the stack trace is very difficult to follow; it goes pretty deep into funsor.
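A minimal sketch of one thing that may be worth checking, assuming that parallel enumeration gives `assignment` an extra leading axis of size 2 (holding the candidate values 0 and 1): plain indexing such as `assignment[construct_code]` would then index that enumeration axis instead of the construct axis. The sketch uses NumPy with small made-up sizes; the arrays are hypothetical stand-ins for the model's `assignment` and `construct_code`, not the real data.

```python
import numpy as np

# Hypothetical stand-ins: 5 constructs, 8 observations (the real model has 10151)
n_constructs = 5
construct_code = np.array([0, 3, 1, 4, 4, 2, 0, 3])  # construct index per observation

# Ordinary (non-enumerated) sample: one 0/1 value per construct
assignment = np.array([0, 1, 0, 0, 1])
per_obs = assignment[construct_code]  # shape (8,), one value per observation

# Assumed enumerated form: a leading axis of size 2 holds the values 0 and 1
assignment_enum = np.broadcast_to(np.arange(2)[:, None], (2, n_constructs))

# Indexing from the left now targets the size-2 enumeration axis.
# NumPy raises IndexError for codes >= 2; JAX by default clamps them instead,
# which would silently produce an array of the wrong shape.
try:
    wrong = assignment_enum[construct_code]
except IndexError as err:
    wrong = err

# Indexing from the right keeps the enumeration axis and maps constructs to obs
right = assignment_enum[..., construct_code]  # shape (2, 8)
```

In NumPyro, `numpyro.contrib.indexing.Vindex` is intended for indexing arrays that may carry enumeration dimensions (e.g. `Vindex(assignment)[..., construct_code]`); whether that alone resolves the error above is untested.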
Any suggestions of what to investigate next?