Hi Pyro community,
I’m working on a Pyro implementation of a high-dimensional regression (P ≈ 25,000 predictors, N = 12,000 observations) that uses a horseshoe prior to induce sparsity. I’ve verified that the model works on toy examples, but on the real data each SVI iteration takes about 350 seconds on CPU. The guide is a slight modification of a multivariate normal guide.
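For scale, here is a back-of-the-envelope memory estimate for these sizes. It assumes every tensor is dense float32 (PyTorch's default dtype) and that each multivariate normal site uses a dense P × P `scale_tril` parameter, as in the guide code in this post. Gradients and Adam's two per-parameter moment buffers would add further copies of the same size on top of this.

```python
# Back-of-the-envelope memory estimate for the sizes in this post.
# Assumes every tensor is dense float32 (PyTorch's default dtype).
P = 25_000               # number of predictors
N = 12_000               # number of observations
BYTES_PER_FLOAT32 = 4


def gigabytes(num_floats: int) -> float:
    """Size in GB (1e9 bytes) of a dense float32 tensor with num_floats entries."""
    return num_floats * BYTES_PER_FLOAT32 / 1e9


# Each lower_cholesky parameter (lam_scale, weights_scale) is stored as a full P x P matrix.
scale_tril_gb = gigabytes(P * P)
# The guide has two such parameters.
guide_scale_gb = 2 * scale_tril_gb
# The design matrix x used in torch.mv(x, weights) is N x P.
x_gb = gigabytes(N * P)

print(f"one scale_tril: {scale_tril_gb:.1f} GB")     # one scale_tril: 2.5 GB
print(f"both scale_trils: {guide_scale_gb:.1f} GB")  # both scale_trils: 5.0 GB
print(f"design matrix x: {x_gb:.1f} GB")             # design matrix x: 1.2 GB
```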
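Here’s the model code:
```python
import torch
import pyro
import pyro.distributions as dist

def model(x, y, group_indicator, P, G, N):
    # Horseshoe prior: global scale tau, local scales lam, weights ~ Normal(0, lam * tau).
    # .to_event(1) (rather than a plate) makes the "lam" and "weights" event shapes
    # match the full-rank MultivariateNormal factors used for them in the guide.
    tau = pyro.sample("tau", dist.HalfCauchy(torch.ones(1)))
    lam = pyro.sample("lam", dist.HalfCauchy(torch.ones(P)).to_event(1))
    weights = pyro.sample("weights", dist.Normal(torch.zeros(P), lam * tau).to_event(1))
    # One random intercept per group
    with pyro.plate("random_effect_plate", G):
        intercepts = pyro.sample("intercepts", dist.Normal(torch.zeros(G), torch.ones(G)))
    # Gaussian likelihood with unit noise scale
    with pyro.plate("data", N):
        mean = torch.mv(x, weights) + intercepts[group_indicator]
        y = pyro.sample("y", dist.Normal(mean, torch.ones(1)), obs=y)
```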
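Written out, the model above is the following, where g(n) denotes `group_indicator[n]` and the second argument of each normal is a variance (the code passes the standard deviation λ_p τ):

```latex
\begin{aligned}
\tau &\sim \mathrm{HalfCauchy}(1), \\
\lambda_p &\sim \mathrm{HalfCauchy}(1), & p &= 1, \dots, P, \\
w_p \mid \lambda_p, \tau &\sim \mathcal{N}\!\left(0, (\lambda_p \tau)^2\right), & p &= 1, \dots, P, \\
b_g &\sim \mathcal{N}(0, 1), & g &= 1, \dots, G, \\
y_n \mid w, b &\sim \mathcal{N}\!\left(x_n^\top w + b_{g(n)}, 1\right), & n &= 1, \dots, N.
\end{aligned}
```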
Here’s the guide:
```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints, transforms

def guide(x, y, group_indicator, P, G, N):
    # Log-normal factor for the global scale tau (loc and scale need distinct param names)
    tau_loc = pyro.param("tau_loc", torch.rand(1))
    tau_scale = pyro.param("tau_scale", torch.rand(1), constraint=constraints.positive)
    # Full-rank Gaussian factors: each scale_tril is a dense P x P matrix
    lam_loc = pyro.param("lam_loc", torch.rand(P))
    lam_scale = pyro.param("lam_scale", torch.eye(P), constraint=constraints.lower_cholesky)
    weights_loc = pyro.param("weights_loc", torch.zeros(P))
    weights_scale = pyro.param("weights_scale", torch.eye(P), constraint=constraints.lower_cholesky)
    # Mean-field Gaussian factor for the G random intercepts
    random_intercepts_loc = pyro.param("random_intercepts_loc", torch.zeros(G))
    random_intercepts_scale = pyro.param("random_intercepts_scale", torch.ones(G),
                                         constraint=constraints.positive)

    # exp(Normal) keeps tau positive, matching its HalfCauchy support in the model
    pyro.sample("tau", dist.TransformedDistribution(
        dist.Normal(tau_loc, tau_scale), transforms=transforms.ExpTransform()))
    # exp(MultivariateNormal) keeps every lam_p positive
    pyro.sample("lam", dist.TransformedDistribution(
        dist.MultivariateNormal(lam_loc, scale_tril=lam_scale), transforms=transforms.ExpTransform()))
    pyro.sample("weights", dist.MultivariateNormal(weights_loc, scale_tril=weights_scale))
    # Same plate as in the model, so the intercepts' batch dimension is declared
    with pyro.plate("random_effect_plate", G):
        pyro.sample("intercepts", dist.Normal(random_intercepts_loc, random_intercepts_scale))
```
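Written out, the guide above factorizes over sample sites, with full-rank Gaussians on log λ (elementwise) and on w. Here μ_τ, σ_τ, μ_λ, L_λ, μ_w, L_w, μ_b, σ_b correspond to `tau_loc`, `tau_scale`, `lam_loc`, `lam_scale`, `weights_loc`, `weights_scale`, `random_intercepts_loc`, and `random_intercepts_scale`. The last line counts the free (lower-triangular) entries of one factor L:

```latex
\begin{aligned}
q(\tau, \lambda, w, b) &= q(\tau)\, q(\lambda)\, q(w) \prod_{g=1}^{G} q(b_g), \\
\log \tau &\sim \mathcal{N}(\mu_\tau, \sigma_\tau^2), \\
\log \lambda &\sim \mathcal{N}_P\!\left(\mu_\lambda, L_\lambda L_\lambda^\top\right), \\
w &\sim \mathcal{N}_P\!\left(\mu_w, L_w L_w^\top\right), \\
b_g &\sim \mathcal{N}(\mu_{b,g}, \sigma_{b,g}^2), \qquad g = 1, \dots, G, \\
\#\{\text{free entries of } L\} &= \tfrac{P(P+1)}{2} = 312{,}512{,}500 \quad (P = 25{,}000).
\end{aligned}
```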
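And the SVI code:
```python
import pyro
from tqdm import tqdm

optim = pyro.optim.Adam({"lr": 0.01})
# num_particles=20: each step runs the model and guide 20 times and averages the ELBO estimate
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO(num_particles=20))
pyro.clear_param_store()
for j in tqdm(range(num_iterations)):  # num_iterations is set elsewhere
    loss = svi.step(x, y, group_indicator, P, G, N)
```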
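A rough count of the dense matrix-vector work in one SVI step, using the sizes above. It covers only drawing the two multivariate normal samples in the guide and computing the likelihood mean in the model, for each of the 20 particles; it leaves out `log_prob`, the `lower_cholesky` transform, and the backward pass, so it is a lower bound rather than a timing model.

```python
# Rough count of dense matrix-vector multiply-adds per SVI step (forward sampling only).
P = 25_000          # number of predictors
N = 12_000          # number of observations
NUM_PARTICLES = 20  # Trace_ELBO(num_particles=20) runs model + guide 20 times per step

# Sampling MultivariateNormal(loc, scale_tril=L) computes loc + L @ eps, a dense
# P x P matrix-vector product; the guide does this for both "lam" and "weights".
guide_macs_per_particle = 2 * P * P
# The model's likelihood mean torch.mv(x, weights) is an N x P matrix-vector product.
model_macs_per_particle = N * P

macs_per_step = NUM_PARTICLES * (guide_macs_per_particle + model_macs_per_particle)
print(f"{macs_per_step:.2e} multiply-adds per step")  # 3.10e+10 multiply-adds per step
```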
Is there anything I can do to speed this up dramatically, other than switching from CPU to GPU?
Thanks for developing pyro,
Josh