I’m trying to use Pyro’s implementation of Stable distributions. As a first sanity check, I’m trying to recover a set of known parameters via maximum likelihood estimation (MLE), but I can’t get it to find the correct values. Here is a minimal example:
import torch
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.reparam import MinimalReparam
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# example still executes on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Ground-truth parameters that the fit is expected to recover, in the
# positional order they are passed to dist.Stable below, plus the number of
# synthetic observations to draw.
alpha = 1.4
beta = 0.8
c = 1.1
mu = 3
n = 1_000_000

# Draw the synthetic dataset. Sampling inside the device context makes the
# tensors created here default to `device`.
with torch.device(device):
    data = dist.Stable(alpha, beta, c, mu).sample((n,))

# Start from an empty parameter store so values from an earlier run cannot
# leak into this fit.
pyro.clear_param_store()
# Model used for the fit: four learnable parameters feeding one Stable
# likelihood over all observations.
@MinimalReparam()
def simple_model(data):
    """Condition a Stable likelihood with learnable parameters on ``data``.

    The four distribution parameters are registered as ``pyro.param`` sites
    named "alpha", "beta", "c" and "mu", each with an initial value and a
    constraint. All entries along the first dimension of ``data`` are
    observed at the "obs" site inside a single plate named "data".

    Args:
        data: observations; only ``data.shape[0]`` and the values themselves
            are used. Presumably a 1-D tensor - verify against the caller.

    Returns:
        None. The function only registers Pyro param and sample sites.
    """
    stability = pyro.param(
        "alpha",
        torch.tensor(1.3),
        constraint=constraints.interval(0, 2),
    )
    skew = pyro.param(
        "beta",
        torch.tensor(0.7),
        constraint=constraints.interval(-1, 1),
    )
    scale = pyro.param(
        "c",
        torch.tensor(1.0),
        constraint=constraints.positive,
    )
    loc = pyro.param(
        "mu",
        torch.tensor(2.9),
        constraint=constraints.real,
    )

    num_obs = data.shape[0]
    with pyro.plate("data", num_obs):
        likelihood = dist.Stable(stability, skew, scale, loc)
        pyro.sample("obs", likelihood, obs=data)
# Set up the guide, the ELBO, and the optimizer with its learning-rate
# schedule. Everything that may allocate tensors runs inside the device
# context so that it defaults to `device`.
with torch.device(device):
    # NOTE(review): MinimalReparam presumably rewrites the observed Stable
    # site in terms of auxiliary latent variables, and AutoDelta fits a point
    # estimate for every latent site. If so, the objective is a joint point
    # estimate over parameters and auxiliaries rather than the marginal
    # likelihood, which may explain biased parameter estimates. Consider a
    # guide that represents uncertainty (e.g. AutoNormal) - confirm against
    # the Pyro reparam documentation.
    guide = pyro.infer.autoguide.AutoDelta(simple_model)
    elbo = Trace_ELBO()
    # Evaluate the loss once before training; this runs the model and guide,
    # which creates their parameters in the param store.
    elbo.loss(simple_model, guide, data=data)

    num_steps = 10001
    # A Pyro LR scheduler wraps a *torch* optimizer class (not a PyroOptim
    # instance). The optimizer's own arguments go under "optim_args"; the
    # remaining keys (here "T_max") are forwarded to the torch scheduler.
    scheduler = pyro.optim.CosineAnnealingLR(
        {
            "optimizer": pyro.optim.clipped_adam.ClippedAdam,
            "optim_args": {"lr": 0.01},
            "T_max": num_steps,
        }
    )
    # The scheduler is itself the optimizer object handed to SVI; keep the
    # name `optim` bound for any code that refers to it.
    optim = scheduler
    svi = SVI(simple_model, guide, optim, loss=elbo)

# optimize
losses = []
for _ in range(num_steps):
    loss = svi.step(data)
    losses.append(loss)
    # Advance the cosine schedule once per optimisation step; without this
    # call the learning rate would stay at its initial value.
    scheduler.step()
# Report each fitted parameter next to its ground-truth value.
# `.item()` converts the parameter tensors to plain Python floats so the
# output shows bare numbers instead of tensor representations.
print(f"Parameter estimates (n = {n}):")
print(f"alpha: Estimate = {pyro.param('alpha').item()}, true = {alpha}")
print(f"beta: Estimate = {pyro.param('beta').item()}, true = {beta}")
print(f"c: Estimate = {pyro.param('c').item()}, true = {c}")
print(f"mu: Estimate = {pyro.param('mu').item()}, true = {mu}")
Results:
Parameter estimates (n = 1000000):
alpha: Estimate = 0.3840053975582123, true = 1.4
beta: Estimate = 0.6137597560882568, true = 0.8
c: Estimate = 1.2401551008224487, true = 1.1
mu: Estimate = 2.395522117614746, true = 3
With 1M datapoints, SVI converges to a parameter set that is significantly worse than the initial guesses I passed to the pyro.param() calls.
What am I getting wrong here? Am I misusing the reparameterizer?