jnp.where operations on discrete latent variables during backpropagation

Hi everyone,

I’m running into some behaviour I struggle to understand. I have a model with discrete (Bernoulli) latent variables. The relevant code is this:

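    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    # Discrete (Bernoulli) latent, looked up below by construct
    assignment = numpyro.sample(
        "assignment",
        dist.Bernoulli(probs=0.1),
    )
    # Assignment for each observation; observations with a fixed
    # assignment are forced to 0
    assignments_per_obs = assignment[construct_code]
    assignments_per_obs = jnp.where(
        fixed_assignments, 0, assignments_per_obs
    )

    # Expression of the original sequence when the assignment is 0,
    # otherwise expression of the alternate sequence
    unscaled_strain_expression_per_obs = numpyro.deterministic(
        "unscaled_strain_expression_per_obs",
        jnp.where(
            assignments_per_obs == 0,
            strain_expression[sequence_code],
            strain_expression[alternate_sequence_code],
        ),
    )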

I then have a very simple likelihood.

I can sample from the prior predictive of this model (and get sensible results), and I can run MCMC on it as long as I comment out the likelihood. If I leave the likelihood in, I get this error:

    ValueError: Output mismatch: Bint[10151] vs Bint[2]

I think this is related to some operation in the backward pass during parallel enumeration. Unfortunately, the stack trace is very difficult to follow, but it goes pretty deep into funsor.
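One plausible reading of the error, offered as an assumption rather than a diagnosis: under parallel enumeration, the Bernoulli site's value gains an extra dimension on the left that holds its support {0, 1} (the `Bint[2]`), while the per-observation arrays supply the `Bint[10151]` side. Plain indexing such as `assignment[construct_code]` always indexes the leftmost axis, which under enumeration is the support axis rather than the construct axis. The pure-JAX sketch below uses made-up sizes and a hand-built stand-in for the enumerated value (the exact shape NumPyro hands the model may differ, for example a size-1 construct axis) to show what indexing from the right with `...` changes.

```python
import jax.numpy as jnp

# Hypothetical sizes: 3 constructs, 4 observations. Codes are kept below 2
# so the "wrong" indexing below stays in bounds for this illustration.
construct_code = jnp.array([0, 1, 1, 0])

# Hand-built stand-in for an enumerated Bernoulli value: axis 0 (size 2)
# holds the support {0, 1}, the last axis is the construct plate.
enumerated_assignment = jnp.broadcast_to(jnp.array([[0], [1]]), (2, 3))

# Plain indexing selects along axis 0, i.e. the enumeration axis,
# so support values leak into what should be the observation axis.
wrong = enumerated_assignment[construct_code]

# Ellipsis indexing selects along the last (construct) axis and keeps the
# enumeration axis on the left.
right = enumerated_assignment[..., construct_code]

print(wrong.shape)  # (4, 3)
print(right.shape)  # (2, 4)
```

If this is the culprit, `numpyro.contrib.indexing.Vindex` may also be worth a look, since it is meant for this kind of right-aligned indexing alongside enumerated dimensions.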

Any suggestions of what to investigate next?

For enumeration, you can check out the Gaussian Mixture Model example in the NumPyro documentation to make sure the plates and broadcasted computations are defined correctly for your model. If you don't want to use enumeration, you can use the DiscreteHMCGibbs sampler instead.