Different shaped array inside of scan function

Hey,

I’m trying to replicate this blog post in NumPyro but am running into an error when I use the dist.NegativeBinomial2 likelihood inside the scan function.

The model seems to sample okay most of the time (although with some divergences), but when I use Predictive for the forecast step, I get this error:

TypeError: true_fun and false_fun output must have identical types, got
DIFFERENT ShapedArray(float32[]) vs. ShapedArray(int32[]).
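That message comes from `jax.lax.cond`, which requires both branches to return values with identical shapes and dtypes. Below is a minimal sketch that reproduces it outside NumPyro. The branch functions are hypothetical stand-ins for "return an observed data point" (float32) and "draw from a count distribution" (int32). It also shows that casting one branch so the types agree avoids the error.

```python
import jax.numpy as jnp
from jax import lax


def observed_branch(x):
    # Hypothetical stand-in for returning an observed (float32) data point
    return x.astype(jnp.float32)


def sampled_branch(x):
    # Hypothetical stand-in for a draw from a count distribution (int32)
    return x.astype(jnp.int32)


def mismatched(pred, x):
    # Branch outputs are float32 vs int32, so tracing fails with a TypeError
    return lax.cond(pred, observed_branch, sampled_branch, x)


def matched(pred, x):
    # Casting the integer branch to float32 gives both branches one output type
    return lax.cond(
        pred,
        observed_branch,
        lambda v: sampled_branch(v).astype(jnp.float32),
        x,
    )


x = jnp.asarray(3.0)
try:
    mismatched(jnp.asarray(True), x)
except TypeError as err:
    print("TypeError:", str(err).splitlines()[0])

print(matched(jnp.asarray(False), x))  # 3.0
```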

I’ve tried printing out all the variables and can’t find what is changing from float32 to int32. For reference, the same model specification works fine with a dist.Normal likelihood.

I’ve also tried moving the likelihood outside the scan function:

_, mu = scan(
    transition_fn,
    carry,
    jnp.arange(1, T + future)
)

if training:
    sample('y', dist.NegativeBinomial2(mu + eps, alpha), obs=y[1:])
else:
    sample('y_forecast', dist.NegativeBinomial2(mu + eps, alpha))

However, with this approach the predicted values from the scan function degrade and don’t carry through into the future for more than a few steps, which is why I need the likelihood inside the transition function.

Does anyone have any ideas about where I’m going wrong?

Notebook to replicate can be found here.