Hello. I’m new to Pyro and PyTorch. After working through the first few tutorials, I tried to write a very simple program that recovers the mean and standard deviation of a normally distributed toy data set, but it doesn’t recover the correct values. I was hoping someone could help.
My data set is just samples from a normal distribution with a mean of 10 and a standard deviation of 3:
```python
import torch

def create_data(mu, sd):
    data = torch.zeros(500)
    for i in range(500):
        data[i] = torch.distributions.Normal(mu, sd).sample()
    return data

data = create_data(10, 3)
```
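The loop above draws one sample per iteration. A minimal vectorized equivalent, using only standard `torch.distributions`, passes a sample shape to `sample()` so all values are drawn in one call. The name `create_data_vectorized` and its `n` argument are hypothetical, introduced only for this sketch.

```python
import torch

def create_data_vectorized(mu, sd, n=500):
    # Draw n i.i.d. samples from Normal(mu, sd) in one call;
    # sample((n,)) returns a 1-D tensor of shape (n,).
    return torch.distributions.Normal(float(mu), float(sd)).sample((n,))

torch.manual_seed(0)
data = create_data_vectorized(10, 3)
print(data.shape, data.mean().item(), data.std().item())
```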
I’m familiar with Stan, so I wrote a Stan program, which recovers a mean and sd close to 10 and 3, as expected:
```stan
data {
  int<lower=0> num_data;
  vector[num_data] x;
}
parameters {
  real mu0;
  real<lower=0> sd0;
}
model {
  mu0 ~ normal(0.0, 1.0);
  sd0 ~ gamma(1.0, 1.0);
  x ~ normal(mu0, sd0);
}
```
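As a rough sanity check on what "close to 10" should mean with a Normal(0, 1) prior on `mu0`: if we pretend the standard deviation is known (σ = 3, an assumption; the Stan model actually learns it), the prior is conjugate and the posterior of `mu0` has a closed form. Posterior precision = 1/prior_sd² + n/σ², and posterior mean = (prior_mu/prior_sd² + n·x̄/σ²) / posterior precision. With 500 points the prior pulls the estimate only slightly toward 0. The helper name `normal_mean_posterior` is hypothetical.

```python
import math

def normal_mean_posterior(xbar, n, sigma, prior_mu=0.0, prior_sd=1.0):
    """Conjugate update for a Normal mean with known sigma and a Normal prior."""
    prior_prec = 1.0 / prior_sd**2   # precision contributed by the prior
    data_prec = n / sigma**2         # precision contributed by n observations
    post_prec = prior_prec + data_prec
    post_mean = (prior_mu * prior_prec + xbar * data_prec) / post_prec
    return post_mean, 1.0 / math.sqrt(post_prec)

mean, sd = normal_mean_posterior(xbar=10.0, n=500, sigma=3.0)
print(f"posterior mean ~ {mean:.3f}, posterior sd ~ {sd:.3f}")
# prints: posterior mean ~ 9.823, posterior sd ~ 0.133
```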
This is the Pyro program I wrote to mimic what the Stan program does.
```python
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    # Priors matching the Stan program
    mu0 = pyro.sample("latent_mu0", dist.Normal(0.0, 1.0))
    sd0 = pyro.sample("latent_sd0", dist.Gamma(1.0, 1.0))
    with pyro.plate("observed_data"):
        pyro.sample("obs", dist.Normal(mu0, sd0), obs=data)

def guide(data):
    # Variational parameters for q(mu0) = Normal and q(sd0) = Gamma
    mu0_mu_q = pyro.param("mu0_mu_q", torch.tensor(0.0), constraint=constraints.real)
    mu0_sd_q = pyro.param("mu0_sd_q", torch.tensor(1.0), constraint=constraints.positive)
    sd0_alpha_q = pyro.param("sd0_alpha_q", torch.tensor(1.0), constraint=constraints.positive)
    sd0_beta_q = pyro.param("sd0_beta_q", torch.tensor(1.0), constraint=constraints.positive)
    pyro.sample("latent_mu0", dist.Normal(mu0_mu_q, mu0_sd_q))
    pyro.sample("latent_sd0", dist.Gamma(sd0_alpha_q, sd0_beta_q))

def train(n_steps, data):
    pyro.clear_param_store()
    adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
    optimizer = Adam(adam_params)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    for _ in range(n_steps):
        svi.step(data)
```
No matter how much data I give the model, the learned guide parameters end up basically the same as their initial values (for example, `mu0_mu_q` stays near 0 instead of moving toward 10). What am I doing wrong? What is the best way to write the guide function in this situation?
A quick note: I’m not a data scientist or statistician (I’m a game designer), and my understanding of this material is tenuous, so I apologize if this is overly basic. Thanks for your help!