Hey,
I have been attempting to port the DLT model to NumPyro, but the fitted model only returns a flat trend instead of one with saturating growth (and it also samples poorly).
The code implementing this model can be found in this gist.
The basic outline is below:
def dlt(
    y,
    global_trend='logistic',
    floor=11.75,
    cap=9999,
    damped_factor=0.1,
    n_seasons=52,
    future=0
) -> None:
    """Damped local trend (DLT) model with seasonality, written for NumPyro.

    At each step the prediction is ``global trend + local trend + season``.
    The observation model is a Student-t centred on that prediction. The
    state ``(level, trend, season)`` is propagated with ``scan``. Steps that
    lie inside the data are conditioned on ``y[1:]``. Steps past the data
    (``t >= T``) feed the prediction back in as the pseudo-observation.

    Args:
        y: observed series, indexed as ``y[t]`` along its first axis
            (presumably 1-D; TODO confirm).
        global_trend: 'logistic', 'linear' or 'loglinear'; any other value
            gives a flat global trend equal to ``gl``.
        floor, cap: forwarded to the ``logistic`` global-trend helper.
        damped_factor: damping multiplier applied to the previous local trend.
        n_seasons: number of seasonal states (length of ``init_s``).
        future: number of steps to forecast beyond the end of ``y``.

    Side effects:
        Registers NumPyro sample sites. When ``future > 0``, it also records
        the last ``future`` scan outputs as the deterministic site
        ``y_forecast``.
    """
    T = y.shape[0]
    time_delta = 1  # / max(n_seasons, 1)
    cauchy_sd = jnp.max(y) / 150
    response_sd = jnp.std(y)

    ## smoothing params
    lev_sm = numpyro.sample('lev_sm', dist.Uniform(0, 1))
    slp_sm = numpyro.sample('slp_sm', dist.Uniform(0, 1))
    ## dof
    nu = numpyro.sample('nu', dist.Uniform(2, 20))
    ## residuals
    obs_sigma = numpyro.sample('obs_sigma', dist.HalfCauchy(cauchy_sd))  # maybe fix this?
    ## local trend proportion
    # NOTE(review): 'lt_coef' is never used below, so it does not enter the
    # likelihood and its posterior equals its prior. The site is kept so the
    # set of sampled variables is unchanged. Either wire it into the local
    # trend or remove it.
    numpyro.sample('lt_coef', dist.Uniform(0, 1))
    if global_trend == 'logistic':
        gl = numpyro.sample('gl', dist.Normal(0, 10))
        gb = numpyro.sample('gb', dist.Laplace(0, 1))
    elif global_trend in ['linear', 'loglinear']:
        # linear + loglinear
        gl = numpyro.sample('gl', dist.Normal(0, response_sd))
        gb = numpyro.sample('gb', dist.Normal(0, response_sd))
    else:
        # flat
        gl = numpyro.sample('gl', dist.Normal(0, response_sd))
    ## seasonal params
    sea_sm = numpyro.sample('sea_sm', dist.Uniform(0, 1))
    # 33% lift is within 1 sd probability
    with numpyro.plate('n_seasons', n_seasons):
        init_s = numpyro.sample("init_s", dist.Cauchy(0, 0.3))

    def global_trend_at(t):
        # Dispatch on the Python string at trace time; 'gb' is only read by
        # the branches that sampled it.
        if global_trend == 'logistic':
            return logistic(floor, cap, gl, gb, time_delta, t)
        if global_trend == 'linear':
            return linear(gl, gb, time_delta, t)
        if global_trend == 'loglinear':
            return loglinear(gl, gb, time_delta, t)
        return gl

    def transition_fn(carry, t):
        level, trend, s = carry
        gt_sum = global_trend_at(t)
        # Local component, deliberately NOT clipped. 'level' is the smoothed
        # residual y - gt_sum - season (see the update below), so it is
        # negative whenever the global trend overshoots the data. Clipping
        # it at 0 corrupts the (1 - lev_sm) * lt_sum term of the recursion
        # and zeroes the gradients the sampler relies on.
        lt_sum = level + damped_factor * trend
        # The prediction itself is kept non-negative, as before. Note that
        # this also has zero gradient wherever it binds.
        y_hat = jnp.maximum(gt_sum + lt_sum + s[0], 0.0)
        # Past the data there is no observation, so the prediction is fed
        # back in. The index is clamped so y is never read out of range.
        y_t = jnp.where(t >= T, y_hat, y[jnp.minimum(t, T - 1)])
        ## Update process
        new_level = lev_sm * (y_t - gt_sum - s[0]) + (1 - lev_sm) * lt_sum
        new_trend = slp_sm * (new_level - level) + (1 - slp_sm) * damped_factor * trend
        new_s = sea_sm * (y_t - gt_sum - new_level) + (1 - sea_sm) * s[0]
        # Seasonal states stay frozen over the forecast horizon.
        new_s = jnp.where(t >= T, s[0], new_s)
        season = jnp.concatenate([s[1:], new_s[None]], axis=0)
        y_ = numpyro.sample('y', dist.StudentT(nu, y_hat, obs_sigma))
        return (new_level, new_trend, season), y_

    # Initial state. The local trend is a slope, so it starts at 0 rather
    # than at the first observation (a level value). zeros_like keeps the
    # carry dtype consistent with level_init.
    # TODO(review): level_init does not subtract the global trend or the
    # season at t=0, so early predictions count y[0] on top of gt_sum.
    # Consider y[0] - global_trend_at(0) - init_s[0] once the trend helpers
    # are confirmed to be finite at t=0 (e.g. loglinear).
    level_init = y[0]
    trend_init = jnp.zeros_like(y[0])
    # Rotate so that s[0] inside the first transition (t=1) is init_s[1].
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    carry = (level_init, trend_init, s_init)
    with numpyro.handlers.condition(data={'y': y[1:]}):
        _, ys = scan(
            transition_fn,
            carry,
            jnp.arange(1, T + future)
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])
Any help in debugging this model would be much appreciated!