I'm trying to fit a censored log-likelihood Poisson model to low-count data, and I'm running into issues using SVI.
It seems that when the counts get low, the model loses the ability to recover the true parameters. I don't have this problem with MCMC using PyMC, but I do have it using MCMC with NumPyro. Maybe it's the way I specify the likelihood contribution from the censored data?
The NumPyro way to do it (my `censored_poisson` function below) feels weird; maybe it's just wrong?
Here is reproducible code:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import numpyro
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoMultivariateNormal
import numpyro.distributions as dist
from numpyro.handlers import mask
import jax
from jax import random
import jax.scipy.special as jsp
import jax.numpy as jnp
def simulate(mu, N=1000, seed=None):
    """Simulate Poisson counts that are right-censored by a random stock level.

    Each row draws a latent count from ``Poisson(mu)`` and a stock level in
    ``0..9`` from a fixed categorical distribution. The recorded count is the
    latent count capped at the stock level.

    Args:
        mu: Rate of the Poisson distribution for the latent counts.
        N: Number of rows to simulate.
        seed: Optional seed. When given, a private ``RandomState(seed)`` is
            used so the data set is reproducible. When ``None``, NumPy's
            global random state is used, exactly as before this parameter
            existed.

    Returns:
        A ``pandas.DataFrame`` with columns ``y`` (recorded count),
        ``stock`` (stock level) and ``cens`` (1 when the latent count was at
        least the stock level, else 0).
    """
    # The `np.random` module and a `RandomState` instance expose the same
    # `poisson` / `choice` methods, so either can serve as the generator.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Probability of each stock level 0..9; heavily weighted towards low stock.
    stock_probs = [0.4, 0.2, 0.1, 0.1, 0.05, 0.05, 0.025, 0.025, 0.025, 0.025]

    latent = rng.poisson(mu, size=N)
    stock = rng.choice(len(stock_probs), p=stock_probs, size=N)

    # Counts are non-negative, so capping at `stock` is the only clipping needed.
    y_obs = np.minimum(latent, stock)
    # A row is censored when the latent count reached (or exceeded) the stock,
    # i.e. all we know is that the true count is >= the recorded one.
    cens = (latent >= stock).astype(int)
    return pd.DataFrame({"y": y_obs, "stock": stock, "cens": cens})
def fit_model(model, inputs, step_size=0.001, training_samples=2500, loss=Trace_ELBO()):
    """Fit ``model`` with SVI and an ``AutoMultivariateNormal`` guide.

    Args:
        model: NumPyro model callable.
        inputs: Sequence of positional arguments, unpacked into both the SVI
            run and the guide's ``Predictive`` call.
        step_size: Step size handed to the Adam optimizer.
        training_samples: Number of SVI steps to run.
        loss: ELBO objective passed to ``SVI``. The default instance is
            created once, when the function is defined.

    Returns:
        Tuple ``(svi_result, params, samples, guide)``: the SVI run result,
        its ``params``, 1000 draws from the fitted guide, and the guide.

    Side effects:
        Plots the SVI losses on the current matplotlib figure and sets its
        title. The PRNG keys are fixed (0 for fitting, 1 for sampling).
    """
    guide = AutoMultivariateNormal(
        model, init_loc_fn=numpyro.infer.init_to_median()
    )
    adam = numpyro.optim.Adam(step_size=step_size)

    engine = SVI(model, guide, adam, loss=loss)
    svi_result = engine.run(random.PRNGKey(0), training_samples, *inputs)
    params = svi_result.params

    # Sample from the fitted guide to obtain approximate posterior draws.
    draw_from_guide = Predictive(guide, params=params, num_samples=1000)
    samples = draw_from_guide(random.PRNGKey(1), *inputs)

    plt.plot(svi_result.losses)
    plt.title("ELBO Loss")
    return svi_result, params, samples, guide
# Model
def censored_poisson(mu, cens, y):
    """Add a right-censored Poisson likelihood to the enclosing NumPyro model.

    Rows with ``cens != 1`` contribute the ordinary Poisson log-probability of
    ``y`` through the ``"obs"`` site. Rows with ``cens == 1`` only tell us the
    true count was at least ``y``, so they contribute ``log P(Y >= y)``
    through the ``"censored_label"`` site (a Bernoulli observed at 1).

    Args:
        mu: Poisson rate.
        cens: Array of censoring indicators; 1 marks a censored row.
        y: Array of recorded counts (the censoring bound on censored rows).
    """
    observed_mask = (cens != 1)
    censored_mask = (cens == 1)
    poisson = dist.Poisson(mu)

    # Fully observed rows: exact Poisson likelihood.
    numpyro.sample("obs", poisson.mask(observed_mask), obs=y)

    # Censored rows: the true count is >= y, so the likelihood is
    #   P(Y >= y) = 1 - P(Y <= y - 1) = 1 - cdf(y - 1).
    # Using 1 - cdf(y) would instead be P(Y > y), which wrongly treats a
    # censored y = 0 as evidence that the count was positive and biases the
    # rate upwards -- badly so when most counts are 0.
    #
    # For y = 0 the probability is exactly 1 (no information). It is set
    # explicitly, and the CDF argument is kept at >= 0, so the CDF is never
    # evaluated outside the Poisson support.
    prev_count = jnp.maximum(y - 1, 0)
    censored_prob = jnp.where(y > 0, 1 - poisson.cdf(prev_count), 1.0)
    numpyro.sample(
        "censored_label",
        dist.Bernoulli(censored_prob).mask(censored_mask),
        obs=cens,
    )
def model(cens, y=None):
    """Censored Poisson model with a single shared rate.

    The rate is sampled at site ``"lambd"`` from a ``Gamma(1, 0.1)`` prior and
    passed, together with the data, to ``censored_poisson``.

    Args:
        cens: Censoring indicators, forwarded to ``censored_poisson``.
        y: Recorded counts, forwarded to ``censored_poisson``.
            NOTE(review): the default is ``None``, but ``censored_poisson``
            evaluates a CDF at ``y`` -- presumably ``y`` must always be
            supplied; confirm.
    """
    rate_prior = dist.Gamma(1, 0.1)
    rate = numpyro.sample("lambd", rate_prior)
    censored_poisson(rate, cens=cens, y=y)
# Simulation: generate data with a known rate, fit the model, compare.
MU_TRUE = 0.05
data = simulate(mu=MU_TRUE, N=1000)

# Positional arguments for `model`, in signature order: (cens, y).
inputs = [data["cens"].to_numpy(), data["y"].to_numpy()]

svi_result, params, samples, guide = fit_model(
    model,
    inputs,
    step_size=0.001,
    training_samples=10000,
)

# The posterior mean of `lambd` should be close to MU_TRUE (0.05).
print(samples["lambd"].mean())