NumPyro missing-value imputation with a continuous missing variable


I am trying to implement an MNAR missing-value model in which the missingness mechanism is modeled explicitly, as shown in the NumPyro documentation.

My issue is that my missing variable is continuous, whereas the example in the documentation is discrete. Specifically, I am not sure how to “cancel out values that are not equal to observed values” for a continuous variable, as is done for the discrete case below:

```python
import jax.numpy as jnp
from math import inf
# cancel out enumerated values that are not equal to observed values
log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)
```
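To see what that masking line does, here is a toy sketch in plain JAX. The data, probabilities and array shapes are made up for illustration and are not taken from the documentation's model. The candidate values of a binary `A` are enumerated, candidates that disagree with an observed `A` get log-probability `-inf`, and a `logsumexp` over the enumerated dimension (roughly what marginalizing an enumerated variable amounts to) leaves the observed value's probability for observed rows and a total of 1 for the missing row.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from math import inf

# Toy data (made up): a binary variable A, with -1 marking a missing entry.
A = jnp.array([1, 0, -1])
A_isobs = A >= 0

# Made-up model probability that A == 1 for each row.
p = jnp.array([0.8, 0.3, 0.6])

# Parallel enumeration: candidate values 0 and 1 along a leading axis of
# shape (2, 1), which broadcasts against the 3 data rows to give (2, 3).
Aimp = jnp.arange(2)[:, None]
log_prob = jnp.where(Aimp == 1, jnp.log(p), jnp.log1p(-p))

# Cancel out enumerated values that disagree with an observed value.
log_prob = jnp.where(A_isobs & (Aimp != A), -inf, log_prob)

# Marginalize over the enumerated axis.
marginal = logsumexp(log_prob, axis=0)
print(jnp.exp(marginal))
```

Observed rows end up with the probability of their observed value (0.8 and 0.7), while the missing row sums to 1 over both candidates. For a continuous variable there is no finite set of candidate values to enumerate, so this exact trick has no direct continuous counterpart.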

I assume some kind of distance measure could be used here, but I would appreciate some guidance, since I am not sure exactly how the manual log-probability calculation works.

On a side note, could this model be implemented in base Pyro instead of NumPyro if desired?

Thank you!


Maybe you want to follow this tutorial instead: Bayesian Imputation — NumPyro documentation

I believe the approach should also work in Pyro.

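Regarding how to “cancel out values that are not equal to observed values”: for continuous variables, you can use strong priors/likelihoods like Normal(..., 1e-4). A likelihood that tight gives values far from the observed one negligible (rather than exactly zero) probability, which approximates the hard -inf mask used in the discrete case.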