I have time series data with observations that I’d like to forecast and accompanying exogenous regressors. If I ignore the regressors and just build a static model which looks at only the observed data, it runs very quickly:
def static( observed):
length, width = observed.shape
correlation_chol = ny.sample("correlation_chol", dist.LKJCholesky(width, concentration=1))
variances = ny.sample("variances", dist.HalfCauchy(scale=1).expand((width,)))
cov_chol = ny.deterministic("cov_chol", jnp.sqrt(variances)[..., None] * correlation_chol)
tau = ny.sample("tau", dist.HalfNormal(scale=1))
with reparam(config={"loc": TransformReparam()}):
loc = ny.sample(
"loc",
dist.TransformedDistribution(
dist.Normal(loc=0, scale=1).expand((width,)),
dist.transforms.LowerCholeskyAffine(loc=jnp.zeros(width), scale_tril=cov_chol * tau),
),
)
with ny.plate("data", length):
ny.sample("observed", dist.MultivariateNormal(loc=loc, scale_tril=cov_chol), obs=observed)
The only way I can think to include the regressors in the model is to use a scan
. But this scan
is sort of overkill because there’s no actual dependence of later data points on earlier data points:
def exogenous(observed, exog):
length, n_endog = observed.shape
_, n_exog = exog.shape
correlation_chol = ny.sample("correlation_chol", dist.LKJCholesky(n_endog, concentration=1))
variances = ny.sample("variances", dist.HalfCauchy(scale=1).expand((n_endog,)))
cov_chol = ny.deterministic("cov_chol", jnp.sqrt(variances)[..., None] * correlation_chol)
tau = ny.sample("tau", dist.HalfNormal(scale=1))
with reparam(config={"loc": TransformReparam()}):
loc = ny.sample(
"loc",
dist.TransformedDistribution(
dist.Normal(loc=0, scale=1).expand((n_endog,)),
dist.transforms.LowerCholeskyAffine(loc=jnp.zeros(n_endog), scale_tril=cov_chol * tau),
),
)
loc_coeff = ny.sample("loc_coeff", dist.Normal(loc=0, scale=1).expand((n_endog, n_exog)))
def inner(state, pos):
loc_delta = loc_coeff @ exog[pos - 1, :]
ny.sample("observed", dist.MultivariateNormal(loc=loc + loc_delta, scale_tril=cov_chol))
return None, None
with ny.handlers.condition(data={"observed": observed[1:]}):
scan(inner, None, jnp.arange(1, length))
And the scan
seems to impose a substantial performance cost (somewhat understated in the notebook here compared to my full application code: gratuitous-scan.ipynb · GitHub). Is there a way to phrase this without using scan
?