Hello Pyro community! I’m a Pyro noob, just getting my feet wet with some simple models, and ran into a warning that I could use some help understanding.
I’m attempting to use NUTS to draw from the posterior of a simple Poisson GLM with one covariate and n observations:
y_i ~ Poisson(exp(a + b * x_i)) for i = 1, …, n
a ~ Normal(0, 2)
b ~ Normal(0, 1)
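Written out, the log posterior I have in mind is below. This is a sketch that assumes the second argument of each Normal is a standard deviation (as with scale= in Pyro), and "const" absorbs the log(y_i!) terms and the normalizing constants.

```latex
\log p(a, b \mid y) = \sum_{i=1}^{n} \left[ y_i (a + b x_i) - e^{a + b x_i} \right] - \frac{a^2}{8} - \frac{b^2}{2} + \text{const}
```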
Here’s the code I have so far:
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
from pyro.infer import EmpiricalMarginal

def simulate_data(n, alpha, beta1):
    """
    Simulate data from a Poisson GLM
    n: number of observations
    alpha: intercept
    beta1: slope
    """
    x = np.linspace(start=0, stop=1, num=n)
    log_lam = alpha + beta1 * x
    lam = np.exp(log_lam)
    counts = np.random.poisson(lam)
    out = {'n': n,
           'alpha': alpha,
           'beta1': beta1,
           'x': x.astype('float32'),
           'lambda': lam,
           'counts': counts.astype('float32')}
    return out

data = simulate_data(n=40, alpha=3, beta1=-1.5)
x = torch.tensor(data['x'])
counts = torch.tensor(data['counts'])

def model():
    alpha = pyro.sample('alpha', dist.Normal(loc=0, scale=2))
    beta1 = pyro.sample('beta1', dist.Normal(loc=0, scale=1))
    log_lam = alpha + beta1 * x
    lam = torch.exp(log_lam)
    with pyro.iarange('observe_data'):
        pyro.sample('obs', dist.Poisson(rate=lam), obs=counts)

kernel = NUTS(model, adapt_step_size=True)
num_samples = 1000
warmup_steps = 1000
mcmc = MCMC(kernel, num_samples, warmup_steps).run()
This last line returns the following warning:
/home/max/anaconda3/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site 'obs'
warnings.warn("Encountered NAN log_prob_sum at site '{}'".format(name))
I’m not sure why the joint log probability of the observed counts would be NaN. Any insight into what’s causing this warning, and whether there’s a fix, would be much appreciated. Apologies if this is obvious to others; I’m coming mainly from R and Stan, so I’m doubly out of my depth here.
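For what it's worth, a hand-rolled version of the same log-likelihood might help check where things go non-finite. This is only a sketch in plain Python: poisson_glm_log_lik and the three-point dataset are hypothetical names and made-up values, not part of Pyro or my simulated data.

```python
import math


def poisson_glm_log_lik(alpha, beta1, x, counts):
    """Sum of Poisson log-probabilities with rate exp(alpha + beta1 * x_i).

    Hypothetical helper for illustration; not a Pyro function.
    """
    total = 0.0
    for x_i, y_i in zip(x, counts):
        log_lam = alpha + beta1 * x_i
        # log Poisson(y | lam) = y * log(lam) - lam - log(y!)
        total += y_i * log_lam - math.exp(log_lam) - math.lgamma(y_i + 1)
    return total


# Tiny made-up dataset, for illustration only
x_demo = [0.0, 0.5, 1.0]
counts_demo = [20, 9, 4]
print(poisson_glm_log_lik(3.0, -1.5, x_demo, counts_demo))
```

If I have the math right, this stays finite for moderate values of alpha and beta1, so evaluating it at the values the sampler visits might show whether the rate itself is blowing up.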
Thanks in advance.