I would like to use the Seasonal, global trend model, as documented here, with a hierarchical model for multiple time series.
I’ve started by allowing multiple time series to be passed in, but am not able to get the model to converge consistently. The model I am using:
def sgt(y, seasonality, future=0):
# heuristically, standard derivation of Cauchy prior depends on
# the max value of data
cauchy_sd = jnp.max(y) / 150
with numpyro.plate('plate_titles', y.shape[-1]):
# NB: priors' parameters are taken from
# https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/rlgtcontrol.R
nu = numpyro.sample("nu", dist.Uniform(2, 20))
powx = numpyro.sample("powx", dist.Uniform(0, 1))
sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
offset_sigma = numpyro.sample(
"offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
)
coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
# pow_trend takes values from -0.5 to 1
pow_trend = 1.5 * pow_trend_beta - 0.5
pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))
level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))
def transition_fn(carry, t):
level, s, moving_sum = carry
season = s[0] * level**pow_season
exp_val = level + coef_trend * level**pow_trend + season
exp_val = jnp.clip(exp_val, a_min=0)
mu = numpyro.deterministic('mu', exp_val)
# use expected vale when forecasting
y_t = jnp.where(t >= N, mu, y[t])
moving_sum = (
moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
)
level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
level = level_sm * level_p + (1 - level_sm) * level
level = jnp.clip(level, a_min=0)
new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
# repeat s when forecasting
new_s = jnp.where(t >= N, s[0], new_s)
s = jnp.concatenate([s[1:], new_s[None]], axis=0)
omega = sigma * exp_val**powx + offset_sigma
y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))
return (level, s, moving_sum), y_
N = y.shape[0]
level_init = y[0]
s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
moving_sum = level_init
with numpyro.handlers.condition(data={"y": y[1:]}):
_, ys = scan(
transition_fn, (level_init, s_init, moving_sum), jnp.arange(1, N + future)
)
if future > 0:
numpyro.deterministic("y_forecast", ys[-future:, ...])
Will not converge when I pass in 5 time series.
ts_train_scaled = ts_train / ts_train.max(axis=0)
y_train = jnp.array(ts_train_scaled.values, dtype=jnp.float32)
nuts_kernel = NUTS(sgt, target_accept_prob=0.95)
mcmc = MCMC(nuts_kernel, num_samples=3000, num_warmup=5000, num_chains=3)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, y_train, seasonality=12)
results in the mcmc.print_summary()
:
mean std median 5.0% 95.0% n_eff r_hat
coef_trend[0] 0.36 1.12 -0.31 -0.55 1.94 1.50 112418.62
coef_trend[1] -0.16 1.08 0.41 -1.67 0.78 1.50 122840.10
coef_trend[2] 0.88 1.09 1.37 -0.64 1.89 1.50 76501.25
coef_trend[3] 0.46 0.57 0.62 -0.31 1.06 1.50 61631.89
coef_trend[4] 0.74 0.86 0.30 -0.03 1.94 1.50 51004.46
init_s[0,0] 0.02 1.46 -0.37 -1.54 1.97 1.50 89578.06
init_s[0,1] -0.76 0.31 -0.77 -1.13 -0.38 1.50 29751.61
init_s[0,2] 0.84 0.78 0.36 0.23 1.95 1.50 120843.83
init_s[0,3] -0.16 0.69 -0.43 -0.83 0.79 1.50 70123.95
init_s[0,4] 0.45 0.78 0.72 -0.60 1.25 1.50 71473.78
init_s[1,0] 0.15 0.74 -0.36 -0.38 1.19 nan 110405.20
init_s[1,1] 0.55 1.16 1.11 -1.06 1.61 1.50 124990.59
init_s[1,2] 0.14 1.19 0.29 -1.39 1.52 1.50 212772.50
init_s[1,3] -0.27 1.08 0.18 -1.75 0.77 1.50 430631.81
init_s[1,4] 0.77 0.71 0.95 -0.17 1.53 1.50 65490.65
init_s[2,0] -0.16 1.06 0.30 -1.62 0.85 1.50 170010.61
init_s[2,1] -0.04 0.81 0.24 -1.14 0.77 1.50 122270.51
init_s[2,2] 0.71 0.49 0.79 0.07 1.27 nan 75560.48
init_s[2,3] -0.60 1.10 -1.36 -1.40 0.96 1.50 118743.43
init_s[2,4] -0.84 0.75 -0.55 -1.86 -0.11 1.50 65711.94
init_s[3,0] -0.07 0.98 0.40 -1.44 0.83 1.50 126212.38
init_s[3,1] 0.01 1.17 -0.52 -1.08 1.63 1.50 66828.71
init_s[3,2] -0.54 1.27 -1.23 -1.64 1.24 1.50 99708.92
init_s[3,3] -0.54 1.45 -1.25 -1.85 1.47 1.50 95123.71
init_s[3,4] 0.76 0.55 1.02 -0.01 1.27 1.50 84896.32
init_s[4,0] 0.76 0.49 0.85 0.12 1.30 1.50 97484.64
init_s[4,1] 1.25 0.70 1.71 0.26 1.78 1.50 60106.23
init_s[4,2] 0.92 0.52 0.93 0.27 1.55 1.50 84627.20
init_s[4,3] -1.28 0.57 -1.43 -1.89 -0.51 1.50 55891.23
init_s[4,4] -0.34 1.30 -1.00 -1.50 1.47 1.50 245384.53
init_s[5,0] -0.06 0.91 -0.67 -0.73 1.23 1.50 84004.56
init_s[5,1] 0.06 1.55 0.35 -1.96 1.79 1.50 92012.52
init_s[5,2] 0.06 1.30 -0.62 -1.09 1.89 1.50 63699.32
init_s[5,3] 0.33 0.62 0.28 -0.40 1.11 1.50 68776.95
init_s[5,4] 1.14 0.79 1.42 0.07 1.95 1.50 160233.02
init_s[6,0] 1.38 0.39 1.18 1.03 1.93 1.50 30507.22
init_s[6,1] 0.13 1.18 -0.47 -0.91 1.78 1.50 58611.97
init_s[6,2] -1.06 0.66 -1.36 -1.67 -0.15 1.50 67687.60
init_s[6,3] 0.86 0.80 0.78 -0.08 1.87 1.50 103087.18
init_s[6,4] 0.38 0.78 0.08 -0.38 1.45 1.50 269876.09
init_s[7,0] -0.15 1.28 0.47 -1.93 1.01 1.50 143029.36
init_s[7,1] 0.17 0.77 -0.13 -0.57 1.23 1.50 76967.12
init_s[7,2] -0.24 0.32 -0.31 -0.60 0.18 1.50 50018.43
init_s[7,3] -1.41 0.52 -1.77 -1.78 -0.67 1.50 125770.87
init_s[7,4] -0.94 0.91 -1.17 -1.93 0.28 1.50 62509.69
init_s[8,0] -0.56 0.89 -0.65 -1.60 0.58 1.50 101498.48
init_s[8,1] 0.59 0.56 0.91 -0.19 1.05 1.50 54177.80
init_s[8,2] -0.14 0.75 -0.15 -1.05 0.78 1.50 64354.41
init_s[8,3] 1.46 0.62 1.84 0.58 1.95 1.50 37548.76
init_s[8,4] -0.07 1.15 0.28 -1.62 1.14 1.50 167859.58
init_s[9,0] -0.42 1.09 -0.41 -1.76 0.92 1.50 172441.55
init_s[9,1] 0.74 0.74 0.91 -0.25 1.55 1.50 88531.73
init_s[9,2] 0.90 0.87 1.37 -0.33 1.65 1.50 113046.64
init_s[9,3] -0.33 0.75 -0.40 -1.21 0.61 1.50 89287.06
init_s[9,4] -0.29 0.29 -0.35 -0.61 0.10 1.50 51703.10
init_s[10,0] 0.61 1.18 1.28 -1.06 1.60 1.50 102308.13
init_s[10,1] 0.19 1.54 0.73 -1.91 1.76 1.50 187399.03
init_s[10,2] 0.34 1.08 0.60 -1.10 1.50 1.50 108870.59
init_s[10,3] -1.41 0.57 -1.74 -1.88 -0.61 1.50 52140.90
init_s[10,4] -0.23 1.08 0.22 -1.73 0.81 1.50 112729.67
init_s[11,0] 1.49 0.34 1.65 1.02 1.80 1.50 20130.48
init_s[11,1] -0.52 0.85 -0.19 -1.69 0.32 1.50 182218.47
init_s[11,2] 0.38 1.38 0.57 -1.39 1.97 1.50 227164.64
init_s[11,3] 0.68 0.85 0.20 -0.05 1.87 1.50 174068.81
init_s[11,4] -0.52 1.05 -0.97 -1.53 0.94 1.50 108386.41
level_sm[0] 0.55 0.27 0.64 0.18 0.83 1.50 70060.18
level_sm[1] 0.60 0.17 0.66 0.36 0.77 1.50 39629.94
level_sm[2] 0.64 0.34 0.87 0.15 0.88 1.50 48925.53
level_sm[3] 0.66 0.24 0.83 0.32 0.83 1.50 69822.33
level_sm[4] 0.55 0.12 0.47 0.47 0.72 1.50 47265.89
nu[0] 14.82 2.41 16.12 11.44 16.89 1.50 25090.27
nu[1] 13.65 3.78 15.37 8.41 17.18 1.50 19929.25
nu[2] 8.23 2.96 6.43 5.87 12.40 1.50 30653.68
nu[3] 8.50 4.62 6.30 4.27 14.93 1.50 69428.32
nu[4] 12.01 3.89 9.55 8.96 17.50 1.50 27090.38
offset_sigma[0] 0.16 0.02 0.15 0.14 0.18 1.50 7675.23
offset_sigma[1] 3.10 2.93 1.99 0.20 7.12 1.50 126687.43
offset_sigma[2] 0.62 0.30 0.67 0.23 0.97 1.50 47580.12
offset_sigma[3] 3.29 2.38 3.49 0.27 6.10 1.50 57718.11
offset_sigma[4] 2.65 1.72 3.54 0.24 4.17 1.50 62632.24
pow_season[0] 0.51 0.25 0.60 0.16 0.75 1.50 82478.45
pow_season[1] 0.51 0.21 0.44 0.30 0.79 1.50 32873.70
pow_season[2] 0.51 0.17 0.43 0.35 0.75 1.50 50344.23
pow_season[3] 0.56 0.22 0.46 0.35 0.87 1.50 70952.66
pow_season[4] 0.62 0.27 0.76 0.24 0.86 1.50 57755.46
pow_trend_beta[0] 0.24 0.12 0.18 0.13 0.41 1.50 42623.52
pow_trend_beta[1] 0.65 0.08 0.62 0.56 0.76 1.50 15979.00
pow_trend_beta[2] 0.35 0.10 0.31 0.25 0.49 1.50 41666.16
pow_trend_beta[3] 0.32 0.16 0.24 0.19 0.54 1.50 35616.77
pow_trend_beta[4] 0.35 0.14 0.36 0.17 0.51 1.50 27855.96
powx[0] 0.54 0.22 0.39 0.39 0.85 1.50 79356.06
powx[1] 0.44 0.29 0.38 0.12 0.82 1.50 58651.75
powx[2] 0.36 0.17 0.43 0.13 0.53 1.50 46921.92
powx[3] 0.66 0.27 0.82 0.27 0.88 1.50 49579.72
powx[4] 0.41 0.30 0.28 0.13 0.82 1.50 64655.02
s_sm[0] 0.56 0.27 0.58 0.21 0.88 1.50 38808.21
s_sm[1] 0.46 0.25 0.42 0.18 0.78 1.50 43660.77
s_sm[2] 0.68 0.27 0.86 0.30 0.87 1.50 44946.68
s_sm[3] 0.64 0.08 0.62 0.56 0.74 1.50 9107.75
s_sm[4] 0.52 0.21 0.47 0.29 0.79 1.50 23846.79
sigma[0] 1.53 0.34 1.35 1.24 2.01 1.50 18716.86
sigma[1] 2.62 3.12 0.65 0.17 7.03 1.50 151841.22
sigma[2] 1.62 1.48 1.05 0.17 3.65 1.50 155257.16
sigma[3] 1.51 1.83 0.28 0.15 4.10 1.50 619547.94
sigma[4] 1.02 1.17 0.22 0.16 2.68 1.50 74382.05
Number of divergences: 9000
However, when I run with 1 or 2 time series (I’ve confirmed that each of the 5 time series converges independently), the model converges:
mcmc.run(rng_key, y_train[..., :2], seasonality=12)
Results in:
mean std median 5.0% 95.0% n_eff r_hat
coef_trend[0] -0.02 0.02 -0.02 -0.05 0.00 4132.93 1.00
coef_trend[1] -0.00 0.01 -0.00 -0.02 0.02 4955.30 1.00
init_s[0,0] 0.09 0.08 0.09 -0.04 0.22 6795.17 1.00
init_s[0,1] -0.07 0.06 -0.07 -0.17 0.03 8379.66 1.00
init_s[1,0] 0.03 0.05 0.03 -0.06 0.12 8749.67 1.00
init_s[1,1] -0.02 0.05 -0.02 -0.10 0.06 8393.98 1.00
init_s[2,0] 0.06 0.05 0.06 -0.03 0.14 6817.32 1.00
init_s[2,1] -0.02 0.05 -0.02 -0.10 0.07 7933.21 1.00
init_s[3,0] -0.01 0.06 -0.01 -0.10 0.09 7448.22 1.00
init_s[3,1] -0.09 0.07 -0.10 -0.21 0.02 5832.02 1.00
init_s[4,0] 0.04 0.05 0.04 -0.05 0.12 7174.37 1.00
init_s[4,1] 0.05 0.07 0.05 -0.05 0.17 6700.69 1.00
init_s[5,0] 0.06 0.05 0.06 -0.03 0.14 8191.24 1.00
init_s[5,1] 0.02 0.06 0.01 -0.09 0.12 8123.27 1.00
init_s[6,0] 0.06 0.05 0.06 -0.02 0.14 6699.31 1.00
init_s[6,1] 0.05 0.05 0.04 -0.04 0.13 8430.82 1.00
init_s[7,0] 0.05 0.05 0.06 -0.03 0.14 6489.25 1.00
init_s[7,1] -0.03 0.05 -0.03 -0.12 0.05 7872.68 1.00
init_s[8,0] 0.03 0.05 0.03 -0.06 0.12 6549.39 1.00
init_s[8,1] -0.02 0.06 -0.02 -0.11 0.07 10253.20 1.00
init_s[9,0] -0.05 0.06 -0.05 -0.16 0.05 7296.68 1.00
init_s[9,1] -0.05 0.06 -0.05 -0.14 0.04 8898.04 1.00
init_s[10,0] 0.00 0.06 0.00 -0.10 0.10 8708.00 1.00
init_s[10,1] -0.08 0.06 -0.08 -0.17 0.02 7703.48 1.00
init_s[11,0] 0.01 0.06 0.01 -0.08 0.11 6926.03 1.00
init_s[11,1] -0.04 0.07 -0.04 -0.15 0.06 7641.93 1.00
level_sm[0] 0.66 0.16 0.67 0.41 0.92 7396.50 1.00
level_sm[1] 0.49 0.22 0.50 0.16 0.89 3047.35 1.00
nu[0] 10.83 4.96 10.58 3.29 18.84 11557.78 1.00
nu[1] 6.78 4.60 5.05 2.00 14.36 5167.74 1.00
offset_sigma[0] 0.03 0.02 0.04 0.00 0.06 4553.84 1.00
offset_sigma[1] 0.04 0.03 0.04 0.00 0.08 3769.03 1.00
pow_season[0] 0.57 0.28 0.61 0.14 1.00 13224.98 1.00
pow_season[1] 0.74 0.22 0.80 0.42 1.00 8854.54 1.00
pow_trend_beta[0] 0.51 0.29 0.52 0.10 1.00 13193.83 1.00
pow_trend_beta[1] 0.51 0.29 0.52 0.09 0.99 15865.08 1.00
powx[0] 0.46 0.29 0.45 0.00 0.88 16280.81 1.00
powx[1] 0.47 0.29 0.46 0.00 0.89 14675.20 1.00
s_sm[0] 0.68 0.17 0.69 0.45 0.97 4970.09 1.00
s_sm[1] 0.08 0.09 0.06 0.00 0.19 5599.77 1.00
sigma[0] 0.03 0.02 0.02 0.00 0.06 4467.75 1.00
sigma[1] 0.04 0.04 0.03 0.00 0.09 3527.14 1.00
Number of divergences: 0
Here is the data for the 5 time series:
[[0.93371051, 0.64300408, 0.95567756, 0.6979538 , 0.76896892],
[0.87502375, 0.62234911, 0.9340615 , 0.65774621, 0.74398795],
[0.89842731, 0.66488988, 0.9345194 , 0.67799256, 0.75796578],
[0.78622806, 0.55800573, 0.93064537, 0.68454134, 0.75630942],
[0.82913281, 0.74113209, 0.94648535, 0.71271235, 0.76568955],
[0.88351516, 0.73015649, 0.95547244, 0.7193235 , 0.78208796],
[0.89374491, 0.68741038, 0.95505926, 0.72206039, 0.78485916],
[0.89743164, 0.61631777, 0.95341337, 0.71953237, 0.78877879],
[0.85370598, 0.64081714, 0.94064708, 0.70682795, 0.79324588],
[0.74987143, 0.62753388, 0.95383012, 0.69583601, 0.78315475],
[0.91864621, 0.56334882, 0.94146711, 0.69433016, 0.76187362],
[0.94027137, 0.6762718 , 0.94586243, 0.69045188, 0.78483132],
[1. , 0.66178637, 0.99895423, 0.71463336, 0.82113993],
[0.975537 , 0.6785507 , 0.98172716, 0.71503008, 0.78962486],
[0.91650188, 0.67191343, 0.97854702, 0.71913111, 0.80434578],
[0.95943085, 0.88703244, 0.9726632 , 0.71317424, 0.80464693],
[0.95272826, 0.67341532, 0.97707868, 0.7345105 , 0.81360763],
[0.94321012, 0.64264604, 0.97915435, 0.73943241, 0.80511251],
[0.95053464, 0.7325226 , 0.97722587, 0.73517754, 0.8020118 ],
[0.94462703, 0.68118902, 0.97556803, 0.73248145, 0.80133707],
[0.92150266, 0.60473359, 0.95902804, 0.71087999, 0.77705044],
[0.94592757, 0.66135273, 0.97174026, 0.70898213, 0.78093387],
[0.90923205, 0.59334388, 0.95620248, 0.70364048, 0.75438812],
[0.9369869 , 0.58406081, 0.96382202, 0.71188881, 0.78686109],
[0.98934096, 0.59010947, 1. , 0.73600896, 0.82589099],
[0.88738814, 0.70992145, 0.98016696, 0.71784463, 0.78524167],
[0.89901598, 0.68150995, 0.98362083, 0.72482206, 0.80548564],
[0.77925889, 0.66745375, 0.98152118, 0.70067359, 0.80250356],
[0.98022655, 1. , 0.9835584 , 0.72209406, 0.79753561],
[0.93937041, 0.96445179, 0.94316288, 0.7064905 , 0.75250679],
[0.95481498, 0.81383263, 0.96098556, 0.71576171, 0.77762344],
[0.90698278, 0.70355008, 0.97850512, 0.73540242, 0.75413136],
[0.92627557, 0.74518375, 0.97465047, 0.71159379, 0.76630128],
[0.93453102, 0.68191003, 0.9681339 , 0.70612449, 0.77044084],
[0.90323614, 0.7299269 , 0.94477723, 0.69045974, 0.74166052],
[0.88628025, 0.60683977, 0.92004049, 0.6681216 , 0.74912106],
[0.92016968, 0.72731507, 0.95312994, 0.71067421, 0.74435376],
[0.89153894, 0.7269534 , 0.91180176, 0.66412323, 0.69702023],
[0.86218174, 0.69748519, 0.88348444, 0.64528768, 0.69289504],
[0.83977159, 0.64702101, 0.90570374, 1. , 0.76141677],
[0.87502531, 0.77516809, 0.94678984, 0.99046219, 0.73774107],
[0.8844526 , 0.70355008, 0.9683284 , 0.70287128, 1. ],
[0.8602876 , 0.74091914, 0.95035124, 0.68701354, 0.7476676 ],
[0.76355655, 0.69889255, 0.9493824 , 0.67843138, 0.73420885],
[0.7901021 , 0.68294115, 0.93079101, 0.65052509, 0.72646694],
[0.85631384, 0.58153043, 0.91895295, 0.65359395, 0.72017541],
[0.79693722, 0.65884406, 0.91044285, 0.63378575, 0.71484109],
[0.85490538, 0.81837957, 0.92542905, 0.63730305, 0.72158845],
[0.81261504, 0.54305411, 0.90957051, 0.63041697, 0.69289638],
[0.75565618, 0.57013129, 0.88020204, 0.61195024, 0.69682573],
[0.75562061, 0.58785546, 0.84200194, 0.59125534, 0.69380045],
[0.69595476, 0.49859798, 0.81922079, 0.56901918, 0.6589422 ],
[0.72429661, 0.64105178, 0.82452562, 0.57361443, 0.66191755],
[0.68284248, 0.57691049, 0.84095446, 0.58399711, 0.68510037],
[0.68145763, 0.62386684, 0.85011777, 0.59048639, 0.66685851],
[0.71356247, 0.51880793, 0.84734968, 0.58393118, 0.67277693]]
I am looking for help troubleshooting. Is there an initialization I should try? Should I try scaling the data differently (MinMaxScaler causes none of the time series to converge when run independently)? Thank you for your help.