Fitting models with NUTS is slow

I am comparing a two-component univariate Gaussian mixture fit with Stan against the same model fit with Pyro (via NUTS). I recognize that Stan uses its own variant of NUTS, but Stan is much faster: ~4 seconds vs. ~70 seconds with Pyro.

Stan model

library(rstan)

set.seed(1)

n <- 1000
K <- 2
alpha <- c(0.3, 0.7)
mu <- c(3, 7)
sigma <- 0.1

z <- sample(length(alpha), size=n, replace=TRUE, prob=alpha)
data <- rnorm(n, mu[z], sigma)

stancode <- "
data {
    int N;
    int K;
    vector[N] y;

    // priors
    vector<lower=0>[K] mu0;
    real<lower=0> sigma0;

    real<lower=0> alpha0;
    real<lower=0> beta0;

    vector<lower=0>[K] pi0;
}

parameters {
    vector[K] mu;
    simplex[K] theta;
    real<lower=0> sigma;
}

model {
    real contributions[K];

    mu ~ normal(mu0, sigma0);
    theta ~ dirichlet(pi0);
    sigma ~ gamma(alpha0, beta0);

    for (n in 1:N) {
        for (k in 1:K) {
            contributions[k] = log(theta[k]) + normal_lpdf(y[n] | mu[k], sigma);
        }

        target += log_sum_exp(contributions);
    }
}
"

stan_data <- list(N=length(data),
                  K=length(alpha),
                  y=data,
                  mu0=c(3, 7),
                  sigma0=0.1,
                  pi0=c(3, 7),
                  alpha0=1,
                  beta0=1)

set.seed(1)
results <- stan(model_code=stancode,
                data=stan_data,
                chains=4, iter=2000,
                cores=4,
                seed=1)

Pyro model

from numpy.random import choice, normal
from numpy import array
from torch import tensor

from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS
from pyro.infer import config_enumerate
from pyro import distributions as dist

import pyro


pyro.set_rng_seed(1)

n = 1000
K = 2
alpha = array([0.3, 0.7])
mu = array([3, 7])
sigma = 0.1

z = choice(len(alpha), size=n, p=alpha)
data = tensor(normal(mu[z], sigma, n)).float()

@config_enumerate
def model(data, mu0, sigma0, pi0, alpha0, beta0):
    # priors
    mu = pyro.sample("mu", dist.Normal(mu0, sigma0).to_event(1))
    theta = pyro.sample("theta", dist.Dirichlet(pi0))
    sigma = pyro.sample("sigma", dist.Gamma(alpha0, beta0))
     
    with pyro.plate('data', len(data)):
        Z = pyro.sample('Z', dist.Categorical(theta))
        pyro.sample('obs', dist.Normal(mu[Z], sigma), obs=data)

kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, warmup_steps=1000, num_chains=4)
mcmc.run(data, tensor([3., 7.]), 0.1, tensor([3., 7.]), 1., 1.)
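
For reference, both models target the same marginal likelihood: the Stan model block sums out the component label with log_sum_exp, and config_enumerate asks Pyro to do the same by enumerating Z. A minimal NumPy/SciPy sketch of that quantity, assuming the same parameter names as above (the function names `mixture_loglik` and `mixture_loglik_naive` are made up for illustration):

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm


def mixture_loglik(y, mu, theta, sigma):
    """Log-likelihood of y with the component label summed out."""
    # Shape (N, K): log(theta[k]) + log Normal(y[n] | mu[k], sigma)
    contributions = np.log(theta) + norm.logpdf(y[:, None], loc=mu, scale=sigma)
    # log_sum_exp over components, then sum over observations
    return logsumexp(contributions, axis=1).sum()


def mixture_loglik_naive(y, mu, theta, sigma):
    """Same quantity computed directly from the mixture density."""
    density = sum(t * norm.pdf(y, loc=m, scale=sigma) for t, m in zip(theta, mu))
    return np.log(density).sum()
```

The two functions should agree up to floating-point error; the log_sum_exp form is the one that stays stable when a component's density underflows.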

Am I doing something wrong in the Pyro model that is causing the slowdown? Or is Stan simply better tuned for MCMC than Pyro for this type of model?

Pyro's HMC/NUTS can be pretty slow for small models. This is for technical reasons that are primarily driven by PyTorch. If you want fast HMC/NUTS for small to medium-sized models, try NumPyro. In many cases it seems to be quite a bit faster than Stan.
