Speed up SVI inference of high dimensional linear regression

Hi Pyro community,

I’m working on a Pyro implementation of a high-dimensional regression (P ≈ 25,000, N = 12,000), using a horseshoe prior to induce sparsity. I’ve verified that the model works on toy examples, but on real data each SVI iteration takes ~350 seconds on CPU. I’m using a slightly modified multivariate normal guide.

Here’s the code of the model:

def model(x, y, group_indicator, P, G, N):
    # Global shrinkage scale of the horseshoe prior
    tau = pyro.sample("tau", dist.HalfCauchy(torch.ones(1)))

    # Local (per-coefficient) shrinkage scales
    with pyro.plate("lam_plate", P):
        lam = pyro.sample("lam", dist.HalfCauchy(torch.ones(P)))

    # Regression weights with horseshoe scale lam * tau
    with pyro.plate("weights_plate", P):
        weights = pyro.sample("weights", dist.Normal(torch.zeros(P), torch.ones(P) * lam * tau))

    # Group-level random intercepts
    with pyro.plate("random_effect_plate", G):
        intercepts = pyro.sample("intercepts", dist.Normal(torch.zeros(G), torch.ones(G)))

    # Gaussian likelihood with unit noise
    with pyro.plate("data", N):
        y = pyro.sample("y", dist.Normal(torch.mv(x, weights) + intercepts[group_indicator], torch.ones(1)), obs=y)

Here’s the code of the guide:

def guide(x, y, group_indicator, P, G, N):
    # Variational parameters
    tau_loc = pyro.param("tau_loc", torch.rand(1))
    tau_scale = pyro.param("tau_scale", torch.rand(1), constraint=constraints.positive)
    lam_loc = pyro.param("lam_loc", torch.rand([P]))
    lam_scale = pyro.param("lam_scale", torch.eye(P), constraint=constraints.lower_cholesky)
    weights_loc = pyro.param("weights_loc", torch.zeros([P]))
    weights_scale = pyro.param("weights_scale", torch.eye(P), constraint=constraints.lower_cholesky)

    random_intercepts_loc = pyro.param("random_intercepts_loc", torch.zeros([G]))
    random_intercepts_scale = pyro.param("random_intercepts_scale", torch.ones([G]), constraint=constraints.positive)

    # Log-normal guides for the positive scale parameters tau and lam
    q_tau = pyro.sample("tau", dist.TransformedDistribution(
        dist.Normal(tau_loc, tau_scale),
        transforms.ExpTransform()))
    q_lam = pyro.sample("lam", dist.TransformedDistribution(
        dist.MultivariateNormal(lam_loc, scale_tril=lam_scale),
        transforms.ExpTransform()))

    # Full-covariance normal guide for the regression weights
    q_weights = pyro.sample("weights", dist.MultivariateNormal(weights_loc, scale_tril=weights_scale))

    # Mean-field normal guide for the random intercepts
    q_random_intercepts = pyro.sample("intercepts", dist.Normal(random_intercepts_loc, random_intercepts_scale))

And SVI code:

optim = pyro.optim.Adam({"lr": 0.01})
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO(num_particles=20))
pyro.clear_param_store()
for j in tqdm(range(num_iterations)):
    loss = svi.step(x, y, group_indicator, P, G, N)

Is there anything that I can do to dramatically speed this up besides switching from CPU->GPU?

Thanks for developing pyro,
Josh

one thing that should help is data subsampling (mini-batching).
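
something like the sketch below, based on your model, is the usual pattern: pyro.plate with a subsample_size hands back a mini-batch of row indices and rescales the likelihood so the ELBO stays unbiased (the 100 here is just an example batch size):

import torch
import pyro
import pyro.distributions as dist

def model(x, y, group_indicator, P, G, N):
    tau = pyro.sample("tau", dist.HalfCauchy(torch.ones(1)))

    with pyro.plate("lam_plate", P):
        lam = pyro.sample("lam", dist.HalfCauchy(torch.ones(P)))

    with pyro.plate("weights_plate", P):
        weights = pyro.sample("weights", dist.Normal(torch.zeros(P), torch.ones(P) * lam * tau))

    with pyro.plate("random_effect_plate", G):
        intercepts = pyro.sample("intercepts", dist.Normal(torch.zeros(G), torch.ones(G)))

    # subsample_size draws a random mini-batch of row indices each step;
    # pyro rescales the likelihood term by N / subsample_size automatically
    with pyro.plate("data", N, subsample_size=100) as idx:
        mean = torch.mv(x[idx], weights) + intercepts[group_indicator[idx]]
        pyro.sample("y", dist.Normal(mean, torch.ones(1)), obs=y[idx])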

another thing you might try is using fewer samples in Trace_ELBO. you’ll probably need more iterations but the overall time to convergence will likely go down. in this context you may also want to tune the Adam hyperparameters more carefully
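
for example (these particular values are illustrative, not tuned):

# one ELBO particle per step instead of 20, and explicit Adam hyperparameters to tune
optim = pyro.optim.Adam({"lr": 0.005, "betas": (0.9, 0.999)})
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO(num_particles=1))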

Thank you for the reply! Subsampling 100 rows at a time improves speed to ~80 s/iteration, though that is still quite slow; I'm not sure how subsampling will affect total time to convergence. I definitely need to tune the Adam parameters, but I assume that is unrelated to the speed per iteration.

unfortunately that’s quite a large P

another thing you could do that would probably give you large speed-ups is to demote your multivariate normal distributions to diagonal normals or a LowRankMultivariateNormal (i suspect much of your computation is coming from there: with P ~ 25,000, each P x P scale_tril has hundreds of millions of entries).
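
for the "weights" site that could look like either of the sketches below (the helper names and rank=10 are placeholders, and the same idea applies to "lam"):

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

# option 1: mean-field (diagonal) normal, O(P) parameters instead of a P x P cholesky factor
def q_weights_diagonal(P):
    loc = pyro.param("weights_loc", torch.zeros(P))
    scale = pyro.param("weights_scale", torch.ones(P), constraint=constraints.positive)
    with pyro.plate("weights_plate", P):
        return pyro.sample("weights", dist.Normal(loc, scale))

# option 2: low-rank multivariate normal, O(P * rank) parameters
def q_weights_lowrank(P, rank=10):
    loc = pyro.param("weights_loc", torch.zeros(P))
    cov_factor = pyro.param("weights_cov_factor", 0.01 * torch.randn(P, rank))
    cov_diag = pyro.param("weights_cov_diag", torch.ones(P), constraint=constraints.positive)
    return pyro.sample("weights", dist.LowRankMultivariateNormal(loc, cov_factor=cov_factor, cov_diag=cov_diag))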