DLT Model - Logistic Trend

Hey,

Have been attempting to port the DLT model to Numpyro, however have been finding an issue with it only returning a flat trend, instead of one with saturating growth (in addition to poor sampling).

The code attempting to implement this model can be found in this gist.

And the basic outline can be found below:

def dlt(
    y, 
    global_trend='logistic', 
    floor=11.75, 
    cap=9999, 
    damped_factor=0.1,
    n_seasons=52, 
    future=0
) -> None:

    time_delta = 1 # / max(n_seasons, 1)
    cauchy_sd = jnp.max(y) / 150
    response_sd = jnp.std(y) 

    ## smoothing params
    lev_sm = numpyro.sample('lev_sm', dist.Uniform(0, 1))
    slp_sm = numpyro.sample('slp_sm', dist.Uniform(0, 1))

    ## dof
    nu = numpyro.sample('nu', dist.Uniform(2, 20))

    ## residuals
    obs_sigma = numpyro.sample('obs_sigma', dist.HalfCauchy(cauchy_sd)) # maybe fix this?

    ## local trend proportion
    lt_coef = numpyro.sample('lt_coef', dist.Uniform(0, 1))

    if global_trend == 'logistic':
        gl = numpyro.sample('gl', dist.Normal(0, 10))
        gb = numpyro.sample('gb', dist.Laplace(0, 1))
    elif global_trend in ['linear', 'loglinear']:
        # linear + loglinear
        gl = numpyro.sample('gl', dist.Normal(0, response_sd))
        gb = numpyro.sample('gb', dist.Normal(0, response_sd))
    else:
        # flat
        gl = numpyro.sample('gl', dist.Normal(0, response_sd))

    ## seasonal params
    sea_sm = numpyro.sample('sea_sm', dist.Uniform(0, 1))

    # 33% lift is within 1 sd probability
    with numpyro.plate('n_seasons', n_seasons):
        init_s = numpyro.sample("init_s", dist.Cauchy(0, 0.3))

    def transition_fn(carry, t):
        level, trend, s = carry

        if global_trend == 'logistic':
            gt_sum = logistic(floor, cap, gl, gb, time_delta, t)
        elif global_trend == 'linear':
            gt_sum = linear(gl, gb, time_delta, t)
        elif global_trend == 'loglinear':
            gt_sum = loglinear(gl, gb, time_delta, t)
        else:
            gt_sum = gl

        lt_sum = jnp.clip(
            level + damped_factor * trend, 
            a_min=0.
        )
        y_hat = jnp.clip(gt_sum + lt_sum + s[0], a_min=0.)
        y_t = jnp.where(t >= T, y_hat, y[t])

        level_prev = level
        trend_prev = trend

        ## Update process
        level = lev_sm * (y_t - gt_sum - s[0]) + (1 - lev_sm) * lt_sum
        trend = slp_sm * (level - level_prev) + (1 - slp_sm) * damped_factor * trend_prev
        new_s = sea_sm * (y_t - gt_sum - level) + (1 - sea_sm) * s[0]
        new_s = jnp.where(t >= T, s[0], new_s)
        
        season = jnp.concatenate([s[1:], new_s[None]], axis=0)

        y_ = numpyro.sample('y', dist.StudentT(nu, y_hat, obs_sigma))

        return (level, trend, season), y_ 

    T = y.shape[0]

    level_init, trend_init = y[0], y[0]

    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)

    carry = (level_init, trend_init, s_init)

    with numpyro.handlers.condition(data={'y': y[1:]}):
        _, ys = scan(
            transition_fn,
            carry,
            jnp.arange(1, T + future)
        )

    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

Any help in debugging this model would be much appreciated!