Avoiding `scan` performance cost in time series with exogenous regressors

I have time series data with observations that I’d like to forecast, along with accompanying exogenous regressors. If I ignore the regressors and just build a static model that looks only at the observed data, it runs very quickly:

import jax.numpy as jnp
import numpyro as ny
import numpyro.distributions as dist
from numpyro.handlers import reparam
from numpyro.infer.reparam import TransformReparam


def static(observed):
    length, width = observed.shape

    correlation_chol = ny.sample("correlation_chol", dist.LKJCholesky(width, concentration=1))
    variances = ny.sample("variances", dist.HalfCauchy(scale=1).expand((width,)))
    cov_chol = ny.deterministic("cov_chol", jnp.sqrt(variances)[..., None] * correlation_chol)

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    with reparam(config={"loc": TransformReparam()}):
        loc = ny.sample(
            "loc",
            dist.TransformedDistribution(
                dist.Normal(loc=0, scale=1).expand((width,)),
                dist.transforms.LowerCholeskyAffine(loc=jnp.zeros(width), scale_tril=cov_chol * tau),
            ),
        )

    with ny.plate("data", length):
        ny.sample("observed", dist.MultivariateNormal(loc=loc, scale_tril=cov_chol), obs=observed)

The only way I can think of to include the regressors in the model is to use a scan. But the scan is sort of overkill here, because there’s no actual dependence of later data points on earlier ones:

from numpyro.contrib.control_flow import scan


def exogenous(observed, exog):
    length, n_endog = observed.shape
    _, n_exog = exog.shape

    correlation_chol = ny.sample("correlation_chol", dist.LKJCholesky(n_endog, concentration=1))
    variances = ny.sample("variances", dist.HalfCauchy(scale=1).expand((n_endog,)))
    cov_chol = ny.deterministic("cov_chol", jnp.sqrt(variances)[..., None] * correlation_chol)

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    with reparam(config={"loc": TransformReparam()}):
        loc = ny.sample(
            "loc",
            dist.TransformedDistribution(
                dist.Normal(loc=0, scale=1).expand((n_endog,)),
                dist.transforms.LowerCholeskyAffine(loc=jnp.zeros(n_endog), scale_tril=cov_chol * tau),
            ),
        )

    loc_coeff = ny.sample("loc_coeff", dist.Normal(loc=0, scale=1).expand((n_endog, n_exog)))

    def inner(state, pos):
        # The regressors at time pos - 1 shift the mean of the observation at time pos.
        loc_delta = loc_coeff @ exog[pos - 1, :]
        ny.sample("observed", dist.MultivariateNormal(loc=loc + loc_delta, scale_tril=cov_chol))
        return None, None

    # Condition the per-step "observed" sites on the data from t = 1 onward.
    with ny.handlers.condition(data={"observed": observed[1:]}):
        scan(inner, None, jnp.arange(1, length))

And the scan seems to impose a substantial performance cost (the notebook here somewhat understates it compared to my full application code: gratuitous-scan.ipynb · GitHub). Is there a way to phrase this model without using scan?

What is the loc of observed supposed to be, in math notation? I don’t see why you can’t, e.g., use einsum to compute what you need and then use a batched MultivariateNormal distribution.

Ah, yes, batched MultivariateNormal is just what I wanted. Thanks so much. I must have had some other incidental error when I originally made an inchoate attempt at it.

Ended up cutting runtime by a factor of 4.
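
For anyone who lands here later, here is a sketch of what the batched, scan-free version can look like. This is an illustrative reconstruction rather than the exact code from the thread (the name exogenous_batched is just for illustration), and it assumes the same lag-one relationship between exog and observed as in the scan model above:

def exogenous_batched(observed, exog):
    length, n_endog = observed.shape
    _, n_exog = exog.shape

    correlation_chol = ny.sample("correlation_chol", dist.LKJCholesky(n_endog, concentration=1))
    variances = ny.sample("variances", dist.HalfCauchy(scale=1).expand((n_endog,)))
    cov_chol = ny.deterministic("cov_chol", jnp.sqrt(variances)[..., None] * correlation_chol)

    tau = ny.sample("tau", dist.HalfNormal(scale=1))
    with reparam(config={"loc": TransformReparam()}):
        loc = ny.sample(
            "loc",
            dist.TransformedDistribution(
                dist.Normal(loc=0, scale=1).expand((n_endog,)),
                dist.transforms.LowerCholeskyAffine(loc=jnp.zeros(n_endog), scale_tril=cov_chol * tau),
            ),
        )

    loc_coeff = ny.sample("loc_coeff", dist.Normal(loc=0, scale=1).expand((n_endog, n_exog)))

    # All of the mean shifts at once: (length - 1, n_exog) with (n_endog, n_exog) -> (length - 1, n_endog),
    # the vectorized equivalent of loc_coeff @ exog[pos - 1, :] for each pos.
    loc_delta = jnp.einsum("te,oe->to", exog[:-1], loc_coeff)

    # Batched MultivariateNormal: loc broadcasts to (length - 1, n_endog), one event per time step.
    with ny.plate("data", length - 1):
        ny.sample(
            "observed",
            dist.MultivariateNormal(loc=loc + loc_delta, scale_tril=cov_chol),
            obs=observed[1:],
        )

The einsum computes every per-step mean shift in one call, and the plate with a batched MultivariateNormal replaces the per-step sample sites inside the scan, so no scan is needed at all.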