I am comparing a mixture of a two component univariate gaussian fit by stan vs pyro (via NUTS). I recognize that Stan uses a derivation of NUTS, but Stan is way faster, ~4 seconds vs ~70 seconds with Pyro.
Stan model
library(rstan)
set.seed(1)
n <- 1000
K <- 2
alpha <- c(0.3, 0.7)
mu <- c(3, 7)
sigma <- 0.1
z <- sample(length(alpha), size=n, replace=TRUE, prob=alpha)
data <- rnorm(n, mu[z], sigma)
stancode <- "
data {
int N;
int K;
vector[N] y;
// priors
vector<lower=0>[K] mu0;
real<lower=0> sigma0;
real<lower=0> alpha0;
real<lower=0> beta0;
vector<lower=0>[K] pi0;
}
parameters {
vector[K] mu;
simplex[K] theta;
real<lower=0> sigma;
}
model {
real contributions[K];
mu ~ normal(mu0, sigma0);
theta ~ dirichlet(pi0);
sigma ~ gamma(alpha0, beta0);
for (n in 1:N) {
for (k in 1:K) {
contributions[k] = log(theta[k]) + normal_lpdf(y[n] | mu[k], sigma);
}
target += log_sum_exp(contributions);
}
}
"
stan_data <- list(N=length(data),
K=length(alpha),
y=data,
mu0=c(3, 7),
sigma0=0.1,
pi0=c(3, 7),
alpha0=1,
beta0=1)
set.seed(1)
results <- stan(model_code=stancode,
data=stan_data,
chains=4, iter=2000,
cores=4,
seed=1)
pyro model
from numpy.random import choice, normal
from numpy import array
from torch.distributions.normal import Normal
from torch import tensor
from pyro.infer.mcmc.api import MCMC
from pyro.infer.mcmc import NUTS
from pyro.infer import config_enumerate
from pyro import distributions as dist
from pyro import plate
import pyro
pyro.set_rng_seed(1)
n = 1000
K = 2
alpha = array([0.3, 0.7])
mu = array([3, 7])
sigma = 0.1
z = choice(len(alpha), size=n, p=alpha)
data = tensor(normal(mu[z], sigma, n)).float()
@config_enumerate
def model(data, mu0, sigma0, pi0, alpha0, beta0):
# priors
mu = pyro.sample("mu", dist.Normal(mu0, sigma0).to_event(1))
theta = pyro.sample("theta", dist.Dirichlet(pi0))
sigma = pyro.sample("sigma", dist.Gamma(alpha0, beta0))
with pyro.plate('data', len(data)):
Z = pyro.sample('Z', dist.Categorical(theta))
pyro.sample('obs', dist.Normal(mu[Z], sigma), obs=data)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_samples=1000, warmup_steps=1000, num_chains=4)
mcmc.run(data, tensor([3., 7.]), 0.1, tensor([3., 7.]), 1., 1.)
Am I doing something wrong in pyro model that is causing a slowdown? Is Stan just better tuned for MCMC than Pyro for this type of model?