Sgt timeseires model with zero-inflated likelihood

Hello! I am new to numpyro and I am getting this error, didn’t know if this is an easy fix but hoping for some direction as i have tried casting dtypes to match wherever possible and I am still getting the same error.

def sgt(y, seasonality, future=0):
    # heuristically, standard derivation of Cauchy prior depends on
    # the max value of data
    cauchy_sd = jnp.max(y) / jnp.mean(y)

    # NB: priors' parameters are taken from
    # https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/rlgtcontrol.R
    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
    )

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    # pow_trend takes values from -0.5 to 1
    pow_trend = 1.5 * pow_trend_beta - 0.5
    pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    level_sm = numpyro.sample("level_sm", dist.Beta(2, 2))
    s_sm = numpyro.sample("s_sm", dist.Beta(2, 2))
    init_s = numpyro.sample("init_s", dist.Cauchy(jnp.zeros(seasonality), cauchy_sd))

    def transition_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = jnp.clip(exp_val, 0)
        # use expected vale when forecasting
        y_t = jnp.where(t >= N, exp_val, y[t])

        # Zero-Inflation Probability
        pi = numpyro.sample("pi", dist.Beta(1, 1))  # Probability of zero-inflation
        
        moving_sum = (
            moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
        )
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, 0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # repeat s when forecasting
        new_s = jnp.where(t >= N, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        y_ = numpyro.sample("y", dist.ZeroInflatedPoisson(rate = exp_val, gate  = pi))

        return (level, s, moving_sum), y_

    N = y.shape[0]
    level_init = y[0]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(
            transition_fn, (level_init, s_init, moving_sum), jnp.arange(1, N + future)
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

…raise TypeError(f"{what} must have identical types, got\n{diff}.") TypeError: true_fun and false_fun output must have identical types, got DIFFERENT ShapedArray(float32) vs. ShapedArray(int32).

The model fits just fine but as soon as I use “predictive” to predict ahead I get this error:

Thanks,