Expanding SGT model to multiple time series

Thanks for the suggestion. I managed to solve the issue. However, I’m trying to expand the code to handle correlated time series. I came up with the following model, but I keep getting the error ValueError: Incompatible shapes for broadcasting: shapes=[(55, 5), (55, 55)] and I cannot find what exactly is producing this What I noticed is that y_ = numpyro.sample("y",dist.MultivariateStudentT(df=nu, loc=exp_val, scale_tril=jnp.linalg.cholesky(cov))) is producing a (55,5) array when I think it should be (5,). Is there an issue with the dist.MultivariateStudentT function?

Any ideas are welcome!

def sgt_multiple_correlated(y, seasonality, future=0):
    T, num_series = y.shape
    cauchy_sd = jnp.max(y) / 150
    
    # Define the prior for the correlation matrix
    L_omega = numpyro.sample("L_omega", dist.LKJCholesky(num_series, concentration=1.0))

    # Extract the correlation matrix from the Cholesky factor
    corr_matrix = jnp.matmul(L_omega, L_omega.T)
    
    with numpyro.plate("num_series", num_series):
        nu = numpyro.sample("nu", dist.Uniform(2, 20))
        powx = numpyro.sample("powx", dist.Uniform(0, 1))
        sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
        offset_sigma = numpyro.sample(
            "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
        )

        coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
        pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
        pow_trend = 1.5 * pow_trend_beta - 0.5
        pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

        level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
        s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
        init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality, :] * 0.3))

    def transition_fn(carry, t):
        level, s, moving_sum = carry
        season = s[0] * level**pow_season
        exp_val = level + coef_trend * level**pow_trend + season
        exp_val = nn.softplus(exp_val)  # apply softplus transformation
        y_t = jnp.where(t >= T, exp_val, y[t, :])

        moving_sum = (
            moving_sum + y[t, :] - jnp.where(t >= seasonality, y[t - seasonality, :], 0.0)
        )
        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        level = level_sm * level_p + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        new_s = jnp.where(t >= T, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        # Compute the covariance matrix using the Cholesky factor and the correlation matrix
        sigma_diag = jnp.diag(sigma * exp_val ** powx + offset_sigma)
        cov = jnp.dot(jnp.dot(L_omega, sigma_diag), L_omega.T)
        y_ = numpyro.sample("y", dist.MultivariateStudentT(df=nu, loc=exp_val, scale_tril=jnp.linalg.cholesky(cov)))
        return (level, s, moving_sum), y_[0,]

    level_init = y[0, :]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    moving_sum = level_init
    with numpyro.handlers.condition(data={"y": y[1:, :]}):

        _, ys = scan(
            transition_fn,
            (level_init, s_init, moving_sum),
            jnp.arange(1, T + future),
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:, :])

# y is the same array shared in a previous post
kernel = NUTS(sgt_multiple_correlated)
mcmc = MCMC(kernel, num_warmup=5000, num_samples=5000, num_chains=4)
mcmc.run(random.PRNGKey(0), y, seasonality=12)
mcmc.print_summary()
samples = mcmc.get_samples()

EDIT: After experimenting a bit more with the code, I think there’s a bug with the dist.MultivariateStudentT. I opened an issue in the NumPyro GitHub account.