Simple linear MAP performance

I’m testing out some Bayesian computation frameworks. Here’s a model I’ve written:

from __future__ import annotations  # PyroFeatures is the poster's own container class (not shown)

import pyro
import pyro.distributions as pdt
import torch


def model(features: PyroFeatures, log_prices: torch.Tensor):
    fourier_coef = pyro.sample(
        'fourier_coef',
        pdt.Normal(torch.zeros(12), 0.25 * torch.ones(12)).to_event(1),
    )
    category_coef = pyro.sample(
        'category_coef',
        pdt.Normal(
            torch.zeros(features.unique_categories),  # a loc of 0 is an assumption
            2.0 * torch.ones(features.unique_categories),
        ).to_event(1),
    )
    radius_coef = pyro.sample('radius_coef', pdt.LogNormal(-2.0, 2.0))
    log_position_coef = pyro.sample('log_position_coef', pdt.LogNormal(-2.3, 1.6))
    intercept = pyro.sample('intercept', pdt.Normal(1.7, 0.75))
    sigma = pyro.sample('sigma', pdt.LogNormal(-1.0, 2.0))
    log_expected_value = (
        torch.matmul(features.fourier, fourier_coef)
        + torch.matmul(features.category, category_coef)
        - features.radius * radius_coef
        - features.log_position * log_position_coef
        + intercept
    )
    with pyro.plate("data", len(log_prices)):
        pyro.sample("obs", pdt.Normal(log_expected_value, sigma), obs=log_prices)

def guide(features: PyroFeatures, log_prices: torch.Tensor):
    fourier_coef_loc = pyro.param('fourier_coef_loc', torch.zeros(12))
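    # Initial values are illustrative; the LogNormal sites need positive point estimates.
    category_coef_loc = pyro.param('category_coef_loc', torch.zeros(features.unique_categories))
    radius_coef_loc = pyro.param(
        'radius_coef_loc', 0.1 * torch.ones(1), constraint=pdt.constraints.positive)
    log_position_coef_loc = pyro.param(
        'log_position_coef_loc', 0.1 * torch.ones(1), constraint=pdt.constraints.positive)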
    intercept_loc = pyro.param('intercept_loc', torch.zeros(1))
    sigma_loc = pyro.param('sigma_loc', torch.ones(1), constraint=pdt.constraints.positive)

    fourier_coef = pyro.sample('fourier_coef', pdt.Delta(fourier_coef_loc).to_event(1))
    category_coef = pyro.sample('category_coef', pdt.Delta(category_coef_loc).to_event(1))
    radius_coef = pyro.sample('radius_coef', pdt.Delta(radius_coef_loc))
    log_position_coef = pyro.sample('log_position_coef', pdt.Delta(log_position_coef_loc))
    intercept = pyro.sample('intercept', pdt.Delta(intercept_loc))
    sigma = pyro.sample('sigma', pdt.Delta(sigma_loc))

which I’m fitting with the following boilerplate:

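import numpy as np
import pyro.infer as pin
import pyro.optim as pot

svi = pin.SVI(model, guide, pot.Adam({"lr": .05}), loss=pin.Trace_ELBO())
num_iters = 5000
last_n_iters = 50
all_elbos = []  # full loss history
elbos = []      # sliding window of the most recent losses
for i in range(num_iters):
    elbo = svi.step(pyro_features, log_prices)
    all_elbos.append(elbo)
    elbos.append(elbo)
    if len(elbos) > last_n_iters:
        elbos.pop(0)
        elbo_arr = np.array(elbos)
        # Stop when the window is flat or its mean relative change is tiny.
        if any((
            elbo_arr.std() < 1e-4,
            np.mean(np.abs((elbo_arr[1:] - elbo_arr[:-1]) / elbo_arr[:-1])) < 1e-4,
        )):
            break
    if i % 500 == 0:
        print("Elbo loss: {}".format(elbo))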
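The stopping rule in that loop (stop once the last `last_n_iters` losses have a tiny standard deviation or a tiny mean relative change) can also be written as a small helper over a fixed-size `collections.deque`, which drops the oldest loss automatically. A sketch; `has_converged` and the synthetic loss curve are hypothetical, not part of the original code:

```python
from collections import deque

import numpy as np


def has_converged(window: deque, tol: float = 1e-4) -> bool:
    """True once a full window of losses is flat by either criterion in the loop above."""
    if len(window) < window.maxlen:
        return False  # not enough history yet
    arr = np.asarray(window, dtype=float)
    rel_change = np.abs(np.diff(arr) / arr[:-1])  # |l[t+1] - l[t]| / |l[t]|
    return bool(arr.std() < tol or rel_change.mean() < tol)


# Synthetic, smoothly decaying loss curve standing in for svi.step() output.
losses = 10.0 + np.exp(-0.01 * np.arange(5000))
window = deque(maxlen=50)  # appending to a full deque discards the oldest entry
for step, loss in enumerate(losses):
    window.append(loss)
    if has_converged(window):
        break
print(f"stopped after {step + 1} steps")
```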

I have 1981 datapoints. Fitting this MAP estimate takes about 2000 iterations and 17 seconds. An identical model written in Stan takes about 0.8 seconds and gives nearly identical, if not slightly better, results. This is surprising, given that optimization is not Stan’s main emphasis, while Pyro/PyTorch are built to handle neural networks with thousands of parameters on millions of datapoints.

Why does this take so long? Have I implemented this in a suboptimal way? My thoughts:

  1. Maybe my hand-rolled relative loss tolerance is slowing things down by rebuilding NumPy arrays on every iteration? But it takes about the same amount of time if I remove the convergence check and hard-code the number of iterations.
  2. Maybe L-BFGS (Stan’s default optimizer) is just better than Adam on small problems like this, where curvature information is cheap to exploit? But the Stan fit also seems faster per iteration.
  3. Maybe the PyTorch computational graph just comes with a lot of fixed per-step overhead that makes small datasets and small models uneconomical?
  4. Maybe I’ve screwed this up somehow?
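Hypothesis #2 can be checked without leaving PyTorch: `torch.optim.LBFGS` can minimize the negative log joint density directly, which is exactly the MAP objective. The sketch below assumes a hypothetical three-feature linear model with Normal priors and a log-scale noise parameter, fit on synthetic data; it is not the model from the question, and its priors only loosely echo it.

```python
import torch

torch.manual_seed(0)

# Hypothetical synthetic data: 2000 points, 3 features, intercept 1.7, noise 0.3.
n, d = 2000, 3
X = torch.randn(n, d)
true_w = torch.tensor([0.5, -1.0, 2.0])
y = X @ true_w + 1.7 + 0.3 * torch.randn(n)

# Unconstrained parameters; sigma is optimized on the log scale so it stays positive.
w = torch.zeros(d, requires_grad=True)
b = torch.zeros((), requires_grad=True)
log_sigma = torch.zeros((), requires_grad=True)
Normal = torch.distributions.Normal


def neg_log_joint() -> torch.Tensor:
    """Negative log posterior up to a constant, i.e. the MAP objective."""
    log_lik = Normal(X @ w + b, log_sigma.exp()).log_prob(y).sum()
    log_prior = (
        Normal(0.0, 2.0).log_prob(w).sum()
        + Normal(1.7, 0.75).log_prob(b)
        # A LogNormal(-1, 2) prior on sigma is a Normal(-1, 2) density over log(sigma).
        + Normal(-1.0, 2.0).log_prob(log_sigma)
    )
    return -(log_lik + log_prior)


optimizer = torch.optim.LBFGS(
    [w, b, log_sigma], lr=1.0, max_iter=100, line_search_fn="strong_wolfe"
)


def closure() -> torch.Tensor:
    # L-BFGS may evaluate the objective several times per step, hence the closure.
    optimizer.zero_grad()
    loss = neg_log_joint()
    loss.backward()
    return loss


for _ in range(5):
    optimizer.step(closure)

print(w.detach(), b.item(), log_sigma.exp().item())
```

Counting `closure` calls instead of outer steps gives a fairer comparison with Adam, since each L-BFGS step can evaluate the objective many times.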

It’s almost certainly a combination of #2 and #3, mostly #3. PyTorch has substantial fixed per-operation overhead, which dominates the cost for small models; if that’s the regime you’re interested in, we recommend NumPyro, our JAX-based sibling library, instead.
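A quick way to probe hypothesis #3 is to time identical gradient steps on datasets of very different sizes: if 100x more data costs far less than 100x more time per step, most of the per-step cost at the small size is fixed overhead rather than arithmetic. A rough plain-PyTorch sketch; the sizes and step counts are arbitrary, and timings depend on hardware:

```python
import time

import torch


def seconds_per_step(n: int, d: int = 12, steps: int = 200) -> float:
    """Average wall-clock seconds for one Adam step on a least-squares loss."""
    torch.manual_seed(0)
    X, y = torch.randn(n, d), torch.randn(n)
    w = torch.zeros(d, requires_grad=True)
    opt = torch.optim.Adam([w], lr=0.05)
    start = time.perf_counter()
    for _ in range(steps):
        opt.zero_grad()
        loss = ((X @ w - y) ** 2).mean()
        loss.backward()
        opt.step()
    return (time.perf_counter() - start) / steps


seconds_per_step(2_000, steps=10)  # warm-up so one-time setup costs are not counted
small = seconds_per_step(2_000)    # roughly the size of the question's dataset
large = seconds_per_step(200_000)  # 100x more rows
print(f"n=2,000:   {small * 1e6:.0f} us per step")
print(f"n=200,000: {large * 1e6:.0f} us per step ({large / small:.1f}x for 100x the data)")
```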

Two things you can try to speed things up: disable validation (pyro.enable_validation(False)) to reduce Pyro’s own overhead, and use our JIT-compiled ELBO, pyro.infer.JitTrace_ELBO, to reduce Python overhead. However, most of the overhead is in PyTorch itself, so the gains will be limited.


Thanks for the prompt response; I’ll probably check out NumPyro. Would you say that, if Google commits fully to JAX as its flagship differentiable-computing library and JAX reaches a 1.0 release, NumPyro would become the main focus of this project?

Both Pyro and NumPyro are under active development, and we have no plans to change that.