 # Speed up SVI inference of high dimensional linear regression

Hi Pyro community,

I’m working on a Pyro implementation of a high-dimensional regression (P ≈ 25,000, N = 12,000), using a horseshoe prior to induce sparsity. I’ve verified that the model works on toy examples. On real data, SVI takes ~350 seconds per iteration on CPU. I’m using a slight modification of a multivariate normal guide.

Here’s the code of the model:

```python
import torch
import pyro
import pyro.distributions as dist


def model(x, y, group_indicator, P, G, N):
    # global shrinkage scale
    tau = pyro.sample("tau", dist.HalfCauchy(torch.ones(1)))

    # local shrinkage scales (horseshoe)
    with pyro.plate("lam_plate", P):
        lam = pyro.sample("lam", dist.HalfCauchy(torch.ones(P)))

    with pyro.plate("weights_plate", P):
        weights = pyro.sample("weights", dist.Normal(torch.zeros(P), torch.ones(P) * lam * tau))

    with pyro.plate("random_effect_plate", G):
        intercepts = pyro.sample("intercepts", dist.Normal(torch.zeros(G), torch.ones(G)))

    with pyro.plate("data", N):
        pyro.sample("y", dist.Normal(torch.mv(x, weights) + intercepts[group_indicator], torch.ones(1)), obs=y)
```

Here’s the code of the guide:

```python
from torch.distributions import constraints, transforms


def guide(x, y, group_indicator, P, G, N):
    tau_loc = pyro.param("tau_loc", torch.rand(1))
    tau_scale = pyro.param("tau_scale", torch.rand(1), constraint=constraints.positive)
    lam_loc = pyro.param("lam_loc", torch.rand(P))
    lam_scale = pyro.param("lam_scale", torch.eye(P), constraint=constraints.lower_cholesky)
    weights_loc = pyro.param("weights_loc", torch.zeros(P))
    weights_scale = pyro.param("weights_scale", torch.eye(P), constraint=constraints.lower_cholesky)

    random_intercepts_loc = pyro.param("random_intercepts_loc", torch.zeros(G))
    random_intercepts_scale = pyro.param("random_intercepts_scale", torch.ones(G), constraint=constraints.positive)

    # log-normal posteriors for the positive scale parameters
    q_tau = pyro.sample("tau", dist.TransformedDistribution(
        dist.Normal(tau_loc, tau_scale),
        transforms.ExpTransform()
    ))
    q_lam = pyro.sample("lam", dist.TransformedDistribution(
        dist.MultivariateNormal(lam_loc, scale_tril=lam_scale),
        transforms.ExpTransform()
    ))

    q_weights = pyro.sample("weights", dist.MultivariateNormal(weights_loc, scale_tril=weights_scale))

    q_random_intercepts = pyro.sample("intercepts", dist.Normal(random_intercepts_loc, random_intercepts_scale))
```

And SVI code:

```python
from tqdm import tqdm

optim = pyro.optim.Adam({"lr": 0.01})
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO(num_particles=20))
pyro.clear_param_store()
for j in tqdm(range(num_iterations)):
    loss = svi.step(x, y, group_indicator, P, G, N)
```

Is there anything that I can do to dramatically speed this up besides switching from CPU->GPU?

Thanks for developing pyro,
Josh


one thing that should help is doing data subsampling.

another thing you might try is using fewer samples in `Trace_ELBO`. you’ll probably need more iterations but the overall time to convergence will likely go down. in this context you may also want to tune the `Adam` hyperparameters more carefully

Thanks for the reply! Sub-sampling 100 rows at a time improves speed to ~80 s/iteration, though that is still quite slow, and I'm not sure how sub-sampling will affect total time to convergence. I definitely need to tune the Adam parameters, but I assume that is unrelated to the per-iteration speed.

unfortunately that’s quite a large `P`

another thing you could do that would probably give you large speed-ups is to demote your multivariate normal distributions to diagonal normal distributions or a `LowRankMultivariateNormal` (i suspect much of your computation is coming from there)