Expanding SGT model to multiple time series

I would like to use the Seasonal, global trend model, as documented here, with a hierarchical model for multiple time series.

I’ve started by allowing multiple time series to be passed in, but am not able to get the model to converge consistently. The model I am using:

def sgt(y, seasonality, future=0):
    """Seasonal, global trend (SGT) model, as posted for several series at once.

    Args:
        y: Observed data, indexed by time step along its first axis. The size
            of its last axis is used as the size of ``plate_titles``. In this
            thread it is called with a (time, series) array.
        seasonality: Number of seasonal slots. It sets how many leading rows
            of ``y`` scale the ``init_s`` prior and the window length of the
            moving sum used for the level update.
        future: Number of extra scan steps to run past the end of ``y``. When
            positive, the last ``future`` scan outputs are recorded under the
            deterministic site ``"y_forecast"``.
    """
    # heuristically, standard deviation of Cauchy prior depends on
    # the max value of data
    cauchy_sd = jnp.max(y) / 150

    # One set of parameters per entry along the last axis of y.
    with numpyro.plate('plate_titles', y.shape[-1]):
        # NB: priors' parameters are taken from
        # https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/rlgtcontrol.R
        nu = numpyro.sample("nu", dist.Uniform(2, 20))
        powx = numpyro.sample("powx", dist.Uniform(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
        offset_sigma = numpyro.sample(
            "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
        )

        coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
        pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
        # pow_trend takes values from -0.5 to 1
        pow_trend = 1.5 * pow_trend_beta - 0.5
        pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

        level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
        s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))

    # Initial seasonal terms; the prior scale is 0.3 times the first
    # `seasonality` rows of y.
    # NOTE(review): this site sits outside `plate_titles`; the reply further
    # down the thread suggests sampling it inside a plate with dim=-2.
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    def transition_fn(carry, t):
        # carry holds the current level, the queue of seasonal terms and the
        # running sum of observations.
        level, s, moving_sum = carry
        # Seasonal contribution built from the head of the seasonal queue.
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        # Keep the expected value non-negative.
        exp_val = jnp.clip(exp_val, a_min=0)
        mu = numpyro.deterministic('mu', exp_val)
        # use expected value when forecasting
        y_t = jnp.where(t >= N, mu, y[t])

        # Add the current observation and, once t >= seasonality, subtract the
        # one from `seasonality` steps earlier.
        moving_sum = (
                moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
        )
        # Level proposal: moving average once t >= seasonality, otherwise the
        # de-seasonalised current value.
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        # Blend the proposal with the previous level using weight level_sm.
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # repeat s when forecasting
        new_s = jnp.where(t >= N, s[0], new_s)
        # Drop the head of the queue and append the updated seasonal term.
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Observation scale grows with the expected value (power powx) plus an
        # additive offset.
        omega = sigma * exp_val**powx + offset_sigma
        # NOTE(review): the reply further down the thread suggests using
        # `.to_event(1)` or a `plate_titles` plate at this site.
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        return (level, s, moving_sum), y_

    N = y.shape[0]
    # The first observation seeds both the level and the moving sum.
    level_init = y[0]
    # Rotate init_s by one position: its first row moves to the end.
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # Condition the "y" site on the observations from the second row onward
    # and scan over t = 1, ..., N + future - 1.
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(
            transition_fn, (level_init, s_init, moving_sum), jnp.arange(1, N + future)
        )
    if future > 0:
        # Expose the last `future` scan outputs as the forecast.
        numpyro.deterministic("y_forecast", ys[-future:, ...])

The model does not converge when I pass in 5 time series:

# Divide each column by its own maximum, so every series peaks at 1.
ts_train_scaled = ts_train / ts_train.max(axis=0)
y_train = jnp.array(ts_train_scaled.values, dtype=jnp.float32)

# NUTS kernel over the sgt model with a target acceptance probability of 0.95.
nuts_kernel = NUTS(sgt, target_accept_prob=0.95)

# 5000 warm-up and 3000 retained draws on each of 3 chains.
mcmc = MCMC(nuts_kernel, num_samples=3000, num_warmup=5000, num_chains=3)
rng_key = random.PRNGKey(0)
# `future` is left at its default of 0, so no forecast site is produced.
mcmc.run(rng_key, y_train, seasonality=12)

results in the mcmc.print_summary():

                       mean       std    median      5.0%     95.0%     n_eff     r_hat
    coef_trend[0]      0.36      1.12     -0.31     -0.55      1.94      1.50 112418.62
    coef_trend[1]     -0.16      1.08      0.41     -1.67      0.78      1.50 122840.10
    coef_trend[2]      0.88      1.09      1.37     -0.64      1.89      1.50  76501.25
    coef_trend[3]      0.46      0.57      0.62     -0.31      1.06      1.50  61631.89
    coef_trend[4]      0.74      0.86      0.30     -0.03      1.94      1.50  51004.46
      init_s[0,0]      0.02      1.46     -0.37     -1.54      1.97      1.50  89578.06
      init_s[0,1]     -0.76      0.31     -0.77     -1.13     -0.38      1.50  29751.61
      init_s[0,2]      0.84      0.78      0.36      0.23      1.95      1.50 120843.83
      init_s[0,3]     -0.16      0.69     -0.43     -0.83      0.79      1.50  70123.95
      init_s[0,4]      0.45      0.78      0.72     -0.60      1.25      1.50  71473.78
      init_s[1,0]      0.15      0.74     -0.36     -0.38      1.19       nan 110405.20
      init_s[1,1]      0.55      1.16      1.11     -1.06      1.61      1.50 124990.59
      init_s[1,2]      0.14      1.19      0.29     -1.39      1.52      1.50 212772.50
      init_s[1,3]     -0.27      1.08      0.18     -1.75      0.77      1.50 430631.81
      init_s[1,4]      0.77      0.71      0.95     -0.17      1.53      1.50  65490.65
      init_s[2,0]     -0.16      1.06      0.30     -1.62      0.85      1.50 170010.61
      init_s[2,1]     -0.04      0.81      0.24     -1.14      0.77      1.50 122270.51
      init_s[2,2]      0.71      0.49      0.79      0.07      1.27       nan  75560.48
      init_s[2,3]     -0.60      1.10     -1.36     -1.40      0.96      1.50 118743.43
      init_s[2,4]     -0.84      0.75     -0.55     -1.86     -0.11      1.50  65711.94
      init_s[3,0]     -0.07      0.98      0.40     -1.44      0.83      1.50 126212.38
      init_s[3,1]      0.01      1.17     -0.52     -1.08      1.63      1.50  66828.71
      init_s[3,2]     -0.54      1.27     -1.23     -1.64      1.24      1.50  99708.92
      init_s[3,3]     -0.54      1.45     -1.25     -1.85      1.47      1.50  95123.71
      init_s[3,4]      0.76      0.55      1.02     -0.01      1.27      1.50  84896.32
      init_s[4,0]      0.76      0.49      0.85      0.12      1.30      1.50  97484.64
      init_s[4,1]      1.25      0.70      1.71      0.26      1.78      1.50  60106.23
      init_s[4,2]      0.92      0.52      0.93      0.27      1.55      1.50  84627.20
      init_s[4,3]     -1.28      0.57     -1.43     -1.89     -0.51      1.50  55891.23
      init_s[4,4]     -0.34      1.30     -1.00     -1.50      1.47      1.50 245384.53
      init_s[5,0]     -0.06      0.91     -0.67     -0.73      1.23      1.50  84004.56
      init_s[5,1]      0.06      1.55      0.35     -1.96      1.79      1.50  92012.52
      init_s[5,2]      0.06      1.30     -0.62     -1.09      1.89      1.50  63699.32
      init_s[5,3]      0.33      0.62      0.28     -0.40      1.11      1.50  68776.95
      init_s[5,4]      1.14      0.79      1.42      0.07      1.95      1.50 160233.02
      init_s[6,0]      1.38      0.39      1.18      1.03      1.93      1.50  30507.22
      init_s[6,1]      0.13      1.18     -0.47     -0.91      1.78      1.50  58611.97
      init_s[6,2]     -1.06      0.66     -1.36     -1.67     -0.15      1.50  67687.60
      init_s[6,3]      0.86      0.80      0.78     -0.08      1.87      1.50 103087.18
      init_s[6,4]      0.38      0.78      0.08     -0.38      1.45      1.50 269876.09
      init_s[7,0]     -0.15      1.28      0.47     -1.93      1.01      1.50 143029.36
      init_s[7,1]      0.17      0.77     -0.13     -0.57      1.23      1.50  76967.12
      init_s[7,2]     -0.24      0.32     -0.31     -0.60      0.18      1.50  50018.43
      init_s[7,3]     -1.41      0.52     -1.77     -1.78     -0.67      1.50 125770.87
      init_s[7,4]     -0.94      0.91     -1.17     -1.93      0.28      1.50  62509.69
      init_s[8,0]     -0.56      0.89     -0.65     -1.60      0.58      1.50 101498.48
      init_s[8,1]      0.59      0.56      0.91     -0.19      1.05      1.50  54177.80
      init_s[8,2]     -0.14      0.75     -0.15     -1.05      0.78      1.50  64354.41
      init_s[8,3]      1.46      0.62      1.84      0.58      1.95      1.50  37548.76
      init_s[8,4]     -0.07      1.15      0.28     -1.62      1.14      1.50 167859.58
      init_s[9,0]     -0.42      1.09     -0.41     -1.76      0.92      1.50 172441.55
      init_s[9,1]      0.74      0.74      0.91     -0.25      1.55      1.50  88531.73
      init_s[9,2]      0.90      0.87      1.37     -0.33      1.65      1.50 113046.64
      init_s[9,3]     -0.33      0.75     -0.40     -1.21      0.61      1.50  89287.06
      init_s[9,4]     -0.29      0.29     -0.35     -0.61      0.10      1.50  51703.10
     init_s[10,0]      0.61      1.18      1.28     -1.06      1.60      1.50 102308.13
     init_s[10,1]      0.19      1.54      0.73     -1.91      1.76      1.50 187399.03
     init_s[10,2]      0.34      1.08      0.60     -1.10      1.50      1.50 108870.59
     init_s[10,3]     -1.41      0.57     -1.74     -1.88     -0.61      1.50  52140.90
     init_s[10,4]     -0.23      1.08      0.22     -1.73      0.81      1.50 112729.67
     init_s[11,0]      1.49      0.34      1.65      1.02      1.80      1.50  20130.48
     init_s[11,1]     -0.52      0.85     -0.19     -1.69      0.32      1.50 182218.47
     init_s[11,2]      0.38      1.38      0.57     -1.39      1.97      1.50 227164.64
     init_s[11,3]      0.68      0.85      0.20     -0.05      1.87      1.50 174068.81
     init_s[11,4]     -0.52      1.05     -0.97     -1.53      0.94      1.50 108386.41
      level_sm[0]      0.55      0.27      0.64      0.18      0.83      1.50  70060.18
      level_sm[1]      0.60      0.17      0.66      0.36      0.77      1.50  39629.94
      level_sm[2]      0.64      0.34      0.87      0.15      0.88      1.50  48925.53
      level_sm[3]      0.66      0.24      0.83      0.32      0.83      1.50  69822.33
      level_sm[4]      0.55      0.12      0.47      0.47      0.72      1.50  47265.89
            nu[0]     14.82      2.41     16.12     11.44     16.89      1.50  25090.27
            nu[1]     13.65      3.78     15.37      8.41     17.18      1.50  19929.25
            nu[2]      8.23      2.96      6.43      5.87     12.40      1.50  30653.68
            nu[3]      8.50      4.62      6.30      4.27     14.93      1.50  69428.32
            nu[4]     12.01      3.89      9.55      8.96     17.50      1.50  27090.38
  offset_sigma[0]      0.16      0.02      0.15      0.14      0.18      1.50   7675.23
  offset_sigma[1]      3.10      2.93      1.99      0.20      7.12      1.50 126687.43
  offset_sigma[2]      0.62      0.30      0.67      0.23      0.97      1.50  47580.12
  offset_sigma[3]      3.29      2.38      3.49      0.27      6.10      1.50  57718.11
  offset_sigma[4]      2.65      1.72      3.54      0.24      4.17      1.50  62632.24
    pow_season[0]      0.51      0.25      0.60      0.16      0.75      1.50  82478.45
    pow_season[1]      0.51      0.21      0.44      0.30      0.79      1.50  32873.70
    pow_season[2]      0.51      0.17      0.43      0.35      0.75      1.50  50344.23
    pow_season[3]      0.56      0.22      0.46      0.35      0.87      1.50  70952.66
    pow_season[4]      0.62      0.27      0.76      0.24      0.86      1.50  57755.46
pow_trend_beta[0]      0.24      0.12      0.18      0.13      0.41      1.50  42623.52
pow_trend_beta[1]      0.65      0.08      0.62      0.56      0.76      1.50  15979.00
pow_trend_beta[2]      0.35      0.10      0.31      0.25      0.49      1.50  41666.16
pow_trend_beta[3]      0.32      0.16      0.24      0.19      0.54      1.50  35616.77
pow_trend_beta[4]      0.35      0.14      0.36      0.17      0.51      1.50  27855.96
          powx[0]      0.54      0.22      0.39      0.39      0.85      1.50  79356.06
          powx[1]      0.44      0.29      0.38      0.12      0.82      1.50  58651.75
          powx[2]      0.36      0.17      0.43      0.13      0.53      1.50  46921.92
          powx[3]      0.66      0.27      0.82      0.27      0.88      1.50  49579.72
          powx[4]      0.41      0.30      0.28      0.13      0.82      1.50  64655.02
          s_sm[0]      0.56      0.27      0.58      0.21      0.88      1.50  38808.21
          s_sm[1]      0.46      0.25      0.42      0.18      0.78      1.50  43660.77
          s_sm[2]      0.68      0.27      0.86      0.30      0.87      1.50  44946.68
          s_sm[3]      0.64      0.08      0.62      0.56      0.74      1.50   9107.75
          s_sm[4]      0.52      0.21      0.47      0.29      0.79      1.50  23846.79
         sigma[0]      1.53      0.34      1.35      1.24      2.01      1.50  18716.86
         sigma[1]      2.62      3.12      0.65      0.17      7.03      1.50 151841.22
         sigma[2]      1.62      1.48      1.05      0.17      3.65      1.50 155257.16
         sigma[3]      1.51      1.83      0.28      0.15      4.10      1.50 619547.94
         sigma[4]      1.02      1.17      0.22      0.16      2.68      1.50  74382.05

Number of divergences: 9000

However, when I run with 1 or 2 time series (I’ve confirmed that each of the 5 time series converges independently), the model converges:

mcmc.run(rng_key, y_train[..., :2], seasonality=12)

Results in:

                       mean       std    median      5.0%     95.0%     n_eff     r_hat
    coef_trend[0]     -0.02      0.02     -0.02     -0.05      0.00   4132.93      1.00
    coef_trend[1]     -0.00      0.01     -0.00     -0.02      0.02   4955.30      1.00
      init_s[0,0]      0.09      0.08      0.09     -0.04      0.22   6795.17      1.00
      init_s[0,1]     -0.07      0.06     -0.07     -0.17      0.03   8379.66      1.00
      init_s[1,0]      0.03      0.05      0.03     -0.06      0.12   8749.67      1.00
      init_s[1,1]     -0.02      0.05     -0.02     -0.10      0.06   8393.98      1.00
      init_s[2,0]      0.06      0.05      0.06     -0.03      0.14   6817.32      1.00
      init_s[2,1]     -0.02      0.05     -0.02     -0.10      0.07   7933.21      1.00
      init_s[3,0]     -0.01      0.06     -0.01     -0.10      0.09   7448.22      1.00
      init_s[3,1]     -0.09      0.07     -0.10     -0.21      0.02   5832.02      1.00
      init_s[4,0]      0.04      0.05      0.04     -0.05      0.12   7174.37      1.00
      init_s[4,1]      0.05      0.07      0.05     -0.05      0.17   6700.69      1.00
      init_s[5,0]      0.06      0.05      0.06     -0.03      0.14   8191.24      1.00
      init_s[5,1]      0.02      0.06      0.01     -0.09      0.12   8123.27      1.00
      init_s[6,0]      0.06      0.05      0.06     -0.02      0.14   6699.31      1.00
      init_s[6,1]      0.05      0.05      0.04     -0.04      0.13   8430.82      1.00
      init_s[7,0]      0.05      0.05      0.06     -0.03      0.14   6489.25      1.00
      init_s[7,1]     -0.03      0.05     -0.03     -0.12      0.05   7872.68      1.00
      init_s[8,0]      0.03      0.05      0.03     -0.06      0.12   6549.39      1.00
      init_s[8,1]     -0.02      0.06     -0.02     -0.11      0.07  10253.20      1.00
      init_s[9,0]     -0.05      0.06     -0.05     -0.16      0.05   7296.68      1.00
      init_s[9,1]     -0.05      0.06     -0.05     -0.14      0.04   8898.04      1.00
     init_s[10,0]      0.00      0.06      0.00     -0.10      0.10   8708.00      1.00
     init_s[10,1]     -0.08      0.06     -0.08     -0.17      0.02   7703.48      1.00
     init_s[11,0]      0.01      0.06      0.01     -0.08      0.11   6926.03      1.00
     init_s[11,1]     -0.04      0.07     -0.04     -0.15      0.06   7641.93      1.00
      level_sm[0]      0.66      0.16      0.67      0.41      0.92   7396.50      1.00
      level_sm[1]      0.49      0.22      0.50      0.16      0.89   3047.35      1.00
            nu[0]     10.83      4.96     10.58      3.29     18.84  11557.78      1.00
            nu[1]      6.78      4.60      5.05      2.00     14.36   5167.74      1.00
  offset_sigma[0]      0.03      0.02      0.04      0.00      0.06   4553.84      1.00
  offset_sigma[1]      0.04      0.03      0.04      0.00      0.08   3769.03      1.00
    pow_season[0]      0.57      0.28      0.61      0.14      1.00  13224.98      1.00
    pow_season[1]      0.74      0.22      0.80      0.42      1.00   8854.54      1.00
pow_trend_beta[0]      0.51      0.29      0.52      0.10      1.00  13193.83      1.00
pow_trend_beta[1]      0.51      0.29      0.52      0.09      0.99  15865.08      1.00
          powx[0]      0.46      0.29      0.45      0.00      0.88  16280.81      1.00
          powx[1]      0.47      0.29      0.46      0.00      0.89  14675.20      1.00
          s_sm[0]      0.68      0.17      0.69      0.45      0.97   4970.09      1.00
          s_sm[1]      0.08      0.09      0.06      0.00      0.19   5599.77      1.00
         sigma[0]      0.03      0.02      0.02      0.00      0.06   4467.75      1.00
         sigma[1]      0.04      0.04      0.03      0.00      0.09   3527.14      1.00

Number of divergences: 0

Here is the data for the 5 time series:

[[0.93371051, 0.64300408, 0.95567756, 0.6979538 , 0.76896892],
       [0.87502375, 0.62234911, 0.9340615 , 0.65774621, 0.74398795],
       [0.89842731, 0.66488988, 0.9345194 , 0.67799256, 0.75796578],
       [0.78622806, 0.55800573, 0.93064537, 0.68454134, 0.75630942],
       [0.82913281, 0.74113209, 0.94648535, 0.71271235, 0.76568955],
       [0.88351516, 0.73015649, 0.95547244, 0.7193235 , 0.78208796],
       [0.89374491, 0.68741038, 0.95505926, 0.72206039, 0.78485916],
       [0.89743164, 0.61631777, 0.95341337, 0.71953237, 0.78877879],
       [0.85370598, 0.64081714, 0.94064708, 0.70682795, 0.79324588],
       [0.74987143, 0.62753388, 0.95383012, 0.69583601, 0.78315475],
       [0.91864621, 0.56334882, 0.94146711, 0.69433016, 0.76187362],
       [0.94027137, 0.6762718 , 0.94586243, 0.69045188, 0.78483132],
       [1.        , 0.66178637, 0.99895423, 0.71463336, 0.82113993],
       [0.975537  , 0.6785507 , 0.98172716, 0.71503008, 0.78962486],
       [0.91650188, 0.67191343, 0.97854702, 0.71913111, 0.80434578],
       [0.95943085, 0.88703244, 0.9726632 , 0.71317424, 0.80464693],
       [0.95272826, 0.67341532, 0.97707868, 0.7345105 , 0.81360763],
       [0.94321012, 0.64264604, 0.97915435, 0.73943241, 0.80511251],
       [0.95053464, 0.7325226 , 0.97722587, 0.73517754, 0.8020118 ],
       [0.94462703, 0.68118902, 0.97556803, 0.73248145, 0.80133707],
       [0.92150266, 0.60473359, 0.95902804, 0.71087999, 0.77705044],
       [0.94592757, 0.66135273, 0.97174026, 0.70898213, 0.78093387],
       [0.90923205, 0.59334388, 0.95620248, 0.70364048, 0.75438812],
       [0.9369869 , 0.58406081, 0.96382202, 0.71188881, 0.78686109],
       [0.98934096, 0.59010947, 1.        , 0.73600896, 0.82589099],
       [0.88738814, 0.70992145, 0.98016696, 0.71784463, 0.78524167],
       [0.89901598, 0.68150995, 0.98362083, 0.72482206, 0.80548564],
       [0.77925889, 0.66745375, 0.98152118, 0.70067359, 0.80250356],
       [0.98022655, 1.        , 0.9835584 , 0.72209406, 0.79753561],
       [0.93937041, 0.96445179, 0.94316288, 0.7064905 , 0.75250679],
       [0.95481498, 0.81383263, 0.96098556, 0.71576171, 0.77762344],
       [0.90698278, 0.70355008, 0.97850512, 0.73540242, 0.75413136],
       [0.92627557, 0.74518375, 0.97465047, 0.71159379, 0.76630128],
       [0.93453102, 0.68191003, 0.9681339 , 0.70612449, 0.77044084],
       [0.90323614, 0.7299269 , 0.94477723, 0.69045974, 0.74166052],
       [0.88628025, 0.60683977, 0.92004049, 0.6681216 , 0.74912106],
       [0.92016968, 0.72731507, 0.95312994, 0.71067421, 0.74435376],
       [0.89153894, 0.7269534 , 0.91180176, 0.66412323, 0.69702023],
       [0.86218174, 0.69748519, 0.88348444, 0.64528768, 0.69289504],
       [0.83977159, 0.64702101, 0.90570374, 1.        , 0.76141677],
       [0.87502531, 0.77516809, 0.94678984, 0.99046219, 0.73774107],
       [0.8844526 , 0.70355008, 0.9683284 , 0.70287128, 1.        ],
       [0.8602876 , 0.74091914, 0.95035124, 0.68701354, 0.7476676 ],
       [0.76355655, 0.69889255, 0.9493824 , 0.67843138, 0.73420885],
       [0.7901021 , 0.68294115, 0.93079101, 0.65052509, 0.72646694],
       [0.85631384, 0.58153043, 0.91895295, 0.65359395, 0.72017541],
       [0.79693722, 0.65884406, 0.91044285, 0.63378575, 0.71484109],
       [0.85490538, 0.81837957, 0.92542905, 0.63730305, 0.72158845],
       [0.81261504, 0.54305411, 0.90957051, 0.63041697, 0.69289638],
       [0.75565618, 0.57013129, 0.88020204, 0.61195024, 0.69682573],
       [0.75562061, 0.58785546, 0.84200194, 0.59125534, 0.69380045],
       [0.69595476, 0.49859798, 0.81922079, 0.56901918, 0.6589422 ],
       [0.72429661, 0.64105178, 0.82452562, 0.57361443, 0.66191755],
       [0.68284248, 0.57691049, 0.84095446, 0.58399711, 0.68510037],
       [0.68145763, 0.62386684, 0.85011777, 0.59048639, 0.66685851],
       [0.71356247, 0.51880793, 0.84734968, 0.58393118, 0.67277693]]

I am looking for help troubleshooting. Is there an initialization I should try? Should I try scaling the data differently (MinMaxScaler causes none of the time series to converge when run independently)? Thank you for your help.

1 Like

Interesting, IIUC those time series are independent. Could you try

# Sketch only: `...` marks parts that this reply does not spell out.
with numpyro.plate('plate_titles', y.shape[-1]):
    ...
    with numpyro.plate("init", ..., dim=-2):
        init_s = numpyro.sample("init_s", ...)

def transition_fn(carry, t):
    ...
    # use to_event or add `plate_titles` here.
    y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega).to_event(1))

That worked, thank you! I also needed to scale the data by multiplying it by 10,000, which is on the same magnitude as the example.

You are right; the time series are independent in this example. I had removed the partial pooling while trying to troubleshoot. I will add it back in, confirm it works, and reply here with the partially pooled model for others to reference.

Thanks again.

1 Like

Hi! I have a question related to this problem. If I want to model the time series as correlated rather than independent of each other, is there a way to do so?

I guess a multivariate Student's t distribution, together with an LKJCholesky prior on the correlation matrix, should do the job; I just don't know how to incorporate these components.

Any ideas would be appreciated.

I’m trying to replicate the proposed solution, but I’m getting the following error: ValueError: Incompatible shapes for broadcasting: shapes=[(5, 1), (12, 5)]

Am I missing something?

The full code that I’m using is the following:

def sgt_multiple(y, seasonality, future=0):
    """SGT model for several series, with ``init_s`` in a nested plate.

    This is the version that raises the broadcasting error quoted just above
    it in the thread.

    Args:
        y: Observed data, indexed by time step along its first axis. The size
            of its last axis is used as the size of both plates.
        seasonality: Number of seasonal slots. It sets how many leading rows
            of ``y`` scale the ``init_s`` prior and the window length of the
            moving sum used for the level update.
        future: Number of extra scan steps to run past the end of ``y``. When
            positive, the last ``future`` scan outputs are recorded under the
            deterministic site ``"y_forecast"``.
    """
    # heuristically, standard deviation of Cauchy prior depends on
    # the max value of data
    cauchy_sd = jnp.max(y) / 150

    # One set of parameters per entry along the last axis of y.
    with numpyro.plate('plate_titles', y.shape[-1]):
        # NB: priors' parameters are taken from
        # https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/rlgtcontrol.R
        nu = numpyro.sample("nu", dist.Uniform(2, 20))
        powx = numpyro.sample("powx", dist.Uniform(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
        offset_sigma = numpyro.sample(
            "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
        )

        coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
        pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
        # pow_trend takes values from -0.5 to 1
        pow_trend = 1.5 * pow_trend_beta - 0.5
        pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

        level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
        s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
        
        # NOTE(review): this plate is sized y.shape[-1] (the number of series),
        # while the prior scale below has `seasonality` leading rows. That
        # mismatch looks like the source of the reported (5, 1) vs (12, 5)
        # broadcasting error -- TODO confirm; sizing the plate by
        # `seasonality` is the likely fix.
        with numpyro.plate("init", y.shape[-1], dim=-2):
            init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    def transition_fn(carry, t):
        # carry holds the current level, the queue of seasonal terms and the
        # running sum of observations.
        level, s, moving_sum = carry
        # Seasonal contribution built from the head of the seasonal queue.
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        # Keep the expected value non-negative.
        exp_val = jnp.clip(exp_val, a_min=0)
        mu = numpyro.deterministic('mu', exp_val)
        # use expected value when forecasting
        y_t = jnp.where(t >= N, mu, y[t])

        # Add the current observation and, once t >= seasonality, subtract the
        # one from `seasonality` steps earlier.
        moving_sum = (
                moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
        )
        # Level proposal: moving average once t >= seasonality, otherwise the
        # de-seasonalised current value.
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        # Blend the proposal with the previous level using weight level_sm.
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # repeat s when forecasting
        new_s = jnp.where(t >= N, s[0], new_s)
        # Drop the head of the queue and append the updated seasonal term.
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Observation scale grows with the expected value (power powx) plus an
        # additive offset.
        omega = sigma * exp_val**powx + offset_sigma
        # `.to_event(1)` is applied here as suggested earlier in the thread.
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega).to_event(1))

        return (level, s, moving_sum), y_

    N = y.shape[0]
    # The first observation seeds both the level and the moving sum.
    level_init = y[0]
    # Rotate init_s by one position: its first row moves to the end.
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # Condition the "y" site on the observations from the second row onward
    # and scan over t = 1, ..., N + future - 1.
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(
            transition_fn, (level_init, s_init, moving_sum), jnp.arange(1, N + future)
        )
    if future > 0:
        # Expose the last `future` scan outputs as the forecast.
        numpyro.deterministic("y_forecast", ys[-future:, ...])

# The data values are elided in this post.
y = jnp.array(...)

# NUTS over sgt_multiple with no extra kernel arguments; 5000 warm-up and
# 5000 retained draws on each of 4 chains.
kernel = NUTS(sgt_multiple)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(0), y, seasonality=12)
mcmc.print_summary()
samples = mcmc.get_samples()

I think you can add many print shape statements to make sure that every variable has expected shape.

Thanks for the suggestion. I managed to solve the issue. However, I’m now trying to extend the code to handle correlated time series. I came up with the following model, but I keep getting the error ValueError: Incompatible shapes for broadcasting: shapes=[(55, 5), (55, 55)], and I cannot find what exactly is producing it. What I noticed is that y_ = numpyro.sample("y",dist.MultivariateStudentT(df=nu, loc=exp_val, scale_tril=jnp.linalg.cholesky(cov))) is producing a (55,5) array when I think it should be (5,). Is there an issue with the dist.MultivariateStudentT function?

Any ideas are welcome!

def sgt_multiple_correlated(y, seasonality, future=0):
    """SGT model with a multivariate Student-t likelihood across series.

    This is the version as posted with the question; the follow-up replies in
    the thread discuss the shape of ``nu``.

    Args:
        y: Two-dimensional array of observations, unpacked as
            ``(T, num_series)``.
        seasonality: Number of seasonal slots. It sets how many leading rows
            of ``y`` scale the ``init_s`` prior and the window length of the
            moving sum used for the level update.
        future: Number of extra scan steps to run past the end of ``y``. When
            positive, the last ``future`` scan outputs are recorded under the
            deterministic site ``"y_forecast"``.
    """
    T, num_series = y.shape
    # Scale of the Cauchy priors is tied to the largest observed value.
    cauchy_sd = jnp.max(y) / 150
    
    # Define the prior for the correlation matrix
    L_omega = numpyro.sample("L_omega", dist.LKJCholesky(num_series, concentration=1.0))

    # Extract the correlation matrix from the Cholesky factor
    # NOTE(review): corr_matrix is not referenced again in this function.
    corr_matrix = jnp.matmul(L_omega, L_omega.T)
    
    with numpyro.plate("num_series", num_series):
        # NOTE(review): sampled inside the plate, nu gets one value per series.
        # The replies below report that the shape error went away once nu was
        # moved outside the plate so that it is a scalar.
        nu = numpyro.sample("nu", dist.Uniform(2, 20))
        powx = numpyro.sample("powx", dist.Uniform(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
        offset_sigma = numpyro.sample(
            "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
        )

        coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
        pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
        # Maps the Beta draw from (0, 1) onto (-0.5, 1).
        pow_trend = 1.5 * pow_trend_beta - 0.5
        pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

        level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
        s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
        # Initial seasonal terms; the prior scale is 0.3 times the first
        # `seasonality` rows of y.
        init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality, :] * 0.3))

    def transition_fn(carry, t):
        # carry holds the current level, the queue of seasonal terms and the
        # running sum of observations.
        level, s, moving_sum = carry
        # Seasonal contribution built from the head of the seasonal queue.
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = nn.softplus(exp_val)  # apply softplus transformation
        # Use the expected value in place of an observation once t >= T.
        y_t = jnp.where(t >= T, exp_val, y[t, :])

        # Add the current observation and, once t >= seasonality, subtract the
        # one from `seasonality` steps earlier.
        moving_sum = (
            moving_sum + y[t, :] - jnp.where(t >= seasonality, y[t - seasonality, :], 0.0)
        )
        # Level proposal: moving average once t >= seasonality, otherwise the
        # de-seasonalised current value.
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        # Blend the proposal with the previous level using weight level_sm.
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        # Repeat the current seasonal term once t >= T.
        new_s = jnp.where(t >= T, s[0], new_s)
        # Drop the head of the queue and append the updated seasonal term.
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Compute the covariance matrix using the Cholesky factor and the correlation matrix
        sigma_diag = jnp.diag(sigma * exp_val ** powx + offset_sigma)
        # NOTE(review): this forms L diag(omega) L^T. A more common
        # construction scales the correlation factor instead, e.g.
        # scale_tril = omega[:, None] * L_omega -- confirm the intent.
        cov = jnp.dot(jnp.dot(L_omega, sigma_diag), L_omega.T)
        y_ = numpyro.sample("y", dist.MultivariateStudentT(df=nu, loc=exp_val, scale_tril=jnp.linalg.cholesky(cov)))
        # Only the first row of the sample is emitted as the scan output.
        return (level, s, moving_sum), y_[0,]

    # The first observation seeds both the level and the moving sum.
    level_init = y[0, :]
    # Rotate init_s by one position: its first row moves to the end.
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    # Condition the "y" site on the observations from the second row onward
    # and scan over t = 1, ..., T + future - 1.
    with numpyro.handlers.condition(data={"y": y[1:, :]}):

        _, ys = scan(
            transition_fn,
            (level_init, s_init, moving_sum),
            jnp.arange(1, T + future),
        )
    if future > 0:
        # Expose the last `future` scan outputs as the forecast.
        numpyro.deterministic("y_forecast", ys[-future:, :])

# y is the same array shared in a previous post
# NUTS over sgt_multiple_correlated with no extra kernel arguments; 5000
# warm-up and 5000 retained draws on each of 4 chains.
kernel = NUTS(sgt_multiple_correlated)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(0), y, seasonality=12)
mcmc.print_summary()
samples = mcmc.get_samples()

EDIT: After experimenting a bit more with the code, I think there’s a bug with the dist.MultivariateStudentT. I opened an issue in the NumPyro GitHub account.

What are your expected shapes of MVT parameters nu, loc, scale_tril? I added print statements and saw that they are

nu: (5,)
loc: (5,)
scale_tril: (5, 5)

Is it expected that nu has the shape (5,)? (Note that the MVT distribution uses a scalar degrees of freedom, up to batching, across the event/correlated dimension.) This MVT distribution will have batch_shape (5,) and event_shape (5,), whereas an MVN distribution with the same loc and scale_tril would have batch_shape () and event_shape (5,).

Thanks for bringing that up. I removed the nu variable from the numpyro.plate context and put it outside so it’s a scalar. The code works without issues and divergences. Thanks!