I’m trying to follow this tutorial to implement a MAP (maximum a posteriori) model for linear regression in Pyro. My code looks like this:
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from torch.distributions import constraints
# Use 32-bit floats as the default dtype. ``torch.set_default_tensor_type``
# is deprecated since PyTorch 2.1; ``set_default_dtype(torch.float32)`` is the
# supported equivalent of ``set_default_tensor_type(torch.FloatTensor)``.
torch.set_default_dtype(torch.float32)
class MAP_LR:
    """MAP (maximum a posteriori) fit of a simple linear regression in Pyro.

    The model is ``y ~ Normal(beta_0 + beta_1 * x, sigma)`` with
    ``Normal(0, 1)`` priors on ``beta_0`` and ``beta_1`` and a
    ``HalfCauchy(1)`` prior on ``sigma``. The guide puts a ``Delta`` (point
    mass) on every latent site, so minimizing the ELBO with SVI maximizes the
    log joint density, i.e. it searches for the MAP estimate. The fitted
    values live in Pyro's global param store under ``"beta_0_val"``,
    ``"beta_1_val"`` and ``"sigma_val"``.

    NOTE(review): with thousands of observations the likelihood term
    presumably dominates the ``Normal(., 1)`` priors, so moving a prior mean
    barely changes the estimate. If ``x``/``y`` are on a large numeric scale,
    the initial loss and gradients can be huge and Adam (``lr=0.1`` by
    default) may diverge to NaN; standardizing the data before fitting is
    the usual remedy.
    """

    def __init__(self, n_iter=1000):
        """Store the number of SVI steps that ``fit`` will run.

        Args:
            n_iter: Number of optimization steps.
        """
        self.n_iter_ = n_iter

    def fit(self, x, y):
        """Fit the MAP estimate of the regression to the data ``(x, y)``.

        Args:
            x: Predictor tensor.
            y: Response tensor.

        Returns:
            ``self``, so calls can be chained.
        """
        self.train_point_estimate(self.model_MAP, self.guide_MAP, x, y)
        return self

    ##################################################################
    @staticmethod
    def _print_interval(n_iter):
        """Return how many steps to wait between loss log lines.

        Roughly every 10% of the run, capped at every 10000 steps, and never
        less than every single step.
        """
        # The lower bound of 1 prevents ``step % 0`` (ZeroDivisionError)
        # when n_iter < 10.
        return max(1, min(n_iter // 10, 10000))

    def train_point_estimate(self, model, guide, x, y, lr=0.1):
        """Run ``self.n_iter_`` SVI steps with Adam, logging the loss.

        Clears Pyro's global param store first, so every call starts from
        the guide's initial values. The logged loss is the ELBO loss divided
        by the number of observations ``y.shape[0]``. Training stops early
        if the loss becomes NaN.

        Args:
            model: Pyro model callable taking ``(x, y)``.
            guide: Pyro guide callable taking ``(x, y)``.
            x: Predictor tensor.
            y: Response tensor; its first dimension is the observation count.
            lr: Adam learning rate.
        """
        import math

        pyro.clear_param_store()
        adam = pyro.optim.Adam({"lr": lr})
        self.svi = SVI(model, guide, adam, loss=Trace_ELBO())
        print_interval = self._print_interval(self.n_iter_)
        n_obs = y.shape[0]

        step = loss = None
        for step in range(self.n_iter_):
            loss = self.svi.step(x, y) / n_obs
            if step % print_interval == 0:
                print(f"{step} - loss: {loss}")
            if math.isnan(loss):
                # A NaN loss typically means NaN gradients, which Adam then
                # writes into the parameters; further steps cannot recover.
                print(f"{step} - loss is NaN; stopping early")
                break
        # Log the last step unless it was already logged in the loop. The
        # ``None`` check covers n_iter_ == 0, where the loop never runs.
        if step is not None and step % print_interval != 0:
            print(f"{step} - loss: {loss}")

    ##################################################################
    def model_MAP(self, x, y):
        """Generative model: ``y ~ Normal(beta_0 + beta_1 * x, sigma)``.

        Priors: ``beta_0, beta_1 ~ Normal(0, 1)``; ``sigma ~ HalfCauchy(1)``.

        NOTE(review): ``x`` is assumed to have the same shape as ``y`` (one
        entry per observation). If ``x`` were ``(N, 1)`` while ``y`` is
        ``(N,)``, ``mu`` could broadcast to ``(N, N)`` -- confirm shapes.
        """
        beta_0 = pyro.sample(
            "beta_0", dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
        )
        beta_1 = pyro.sample(
            "beta_1", dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
        )
        sigma = pyro.sample("sigma", dist.HalfCauchy(torch.tensor(1.0)))
        mu = beta_0 + beta_1 * x
        with pyro.plate("data", y.shape[0]):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    def guide_MAP(self, x, y):
        """Point-mass guide: each latent site is a single learnable value.

        With ``Delta`` guides the ELBO reduces to the log joint density, so
        optimizing it yields a MAP estimate. Initial values: ``beta_0 = 0``,
        ``beta_1 = 0``, ``sigma = 10`` (constrained positive).
        """
        beta_0_val = pyro.param("beta_0_val", torch.tensor(0.0))
        pyro.sample("beta_0", dist.Delta(beta_0_val))
        beta_1_val = pyro.param("beta_1_val", torch.tensor(0.0))
        pyro.sample("beta_1", dist.Delta(beta_1_val))
        sigma_val = pyro.param(
            "sigma_val", torch.tensor(10.0), constraint=constraints.positive
        )
        pyro.sample("sigma", dist.Delta(sigma_val))
This works fine, but if I change one of my `beta_*` `pyro.param` statements to start with `torch.tensor(10.0)` instead of `torch.tensor(0.0)`, I get `nan`s for the loss value.
0 - loss: 8.826700978114876e+35
100 - loss: nan
/anaconda3/envs/pyroenv/lib/python3.8/site-packages/pyro/infer/trace_elbo.py:138: UserWarning: Encountered NaN: loss
warn_if_nan(loss, "loss")
200 - loss: nan
300 - loss: nan
400 - loss: nan
500 - loss: nan
600 - loss: nan
700 - loss: nan
800 - loss: nan
900 - loss: nan
999 - loss: nan
Additionally, if I change the prior for `beta_1` from `dist.Normal(torch.tensor(0.0), torch.tensor(1.0))` to `dist.Normal(torch.tensor(100.0), torch.tensor(1.0))`, my final estimate for `beta_1_val` does not change at all.
I tried changing the default tensor type, but it doesn’t help. My `x` and `y` have 2905 values. Why is this behaving so strangely?