Censored log likelihood with low count

I’m trying to fit a censored log likelihood poisson model but with low count data and Im running into issues using SVI.

It seems like when the counts get low, the model loses the ability to recover the true parameters. I dont have this problem with MCMC using pymc, but I do have this problem using MCMC with numpyro. Maybe its the way I specify the likelihood contribution from censored data? The numpyro way to do it (my censored_poisson function below) feels weird, maybe its just wrong?

Here’s reproducible code

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


import numpyro
from numpyro.infer import Predictive, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoMultivariateNormal
import numpyro.distributions as dist
from numpyro.handlers import mask

import jax
from jax import random
import jax.scipy.special as jsp
import jax.numpy as jnp


def simulate(mu, N=1000):
    y = np.random.poisson(mu,size=N)
    stock = np.random.choice(range(10), p=[0.4, 0.2, 0.1, 0.1, 0.05, 0.05,0.025,0.025, 0.025,0.025], size=N)
    y_obs =  np.clip(y, 0, stock)
    return pd.DataFrame({"y":y_obs, "stock":stock, "cens":np.where(stock<=y, 1,0)})



def fit_model(model, inputs, step_size=0.001, training_samples=2500, loss=Trace_ELBO()):
    auto_guide = AutoMultivariateNormal(
        model, 
        init_loc_fn=numpyro.infer.init_to_median())

    optimizer = numpyro.optim.Adam(step_size=step_size) # 0.005
    svi = numpyro.infer.SVI(
        model, 
        auto_guide, #hierarchical_guide, 
        optimizer, loss=loss)

    svi_result = svi.run(random.PRNGKey(0), training_samples, *inputs)
    params = svi_result.params

    # get posterior samples
    predictive = Predictive(auto_guide, params=params, num_samples=1000)
    samples = predictive(random.PRNGKey(1), *inputs)
    plt.plot(svi_result.losses); plt.title("ELBO Loss")
    return svi_result, params, samples, auto_guide


# Model
def censored_poisson(mu, cens, y):
    observed_mask = (cens != 1)
    censored_mask = (cens == 1)
    
    numpyro.sample("obs", dist.Poisson(mu).mask(observed_mask), obs=y)

    # # Censored
    censored_prob = 1 - dist.Poisson(mu).cdf(y)
    numpyro.sample("censored_label", dist.Bernoulli(censored_prob).mask(censored_mask), obs=cens)
    
def model(cens, y=None ):
    
    lambd = numpyro.sample("lambd", dist.Gamma(1,0.1) )
    censored_poisson(lambd, cens=cens, y=y)
    

# Simulation

MU_TRUE = 0.05
data = simulate(mu=MU_TRUE, N=1000)

inputs = [
    data.cens.values,
    data.y.values
]

svi_result, params, samples, guide = fit_model(model, inputs, training_samples=10000, step_size=0.001)

# should be 0.05
print( samples['lambd'].mean() )

This was a bit of a math mistake on my end - since Im using a discrete observational model (poisson distribution), the censored probability is actually the pmf + the complementary cumulative density funciton. Adjusting my censored poisson function does the trick to recover the true parameters (although Im sure there’s an easy math optimization to improve this)

def censored_poisson(mu, cens, y):
    observed_mask = (cens != 1)
    censored_mask = (cens == 1)
    
    numpyro.sample("obs", dist.Poisson(mu).mask(observed_mask), obs=y)

    # # Censored
    ccdf = 1 - dist.Poisson(mu).cdf(y) 
    pdf = jnp.exp( dist.Poisson(mu).log_prob(y)  )
    numpyro.sample("censored_label", dist.Bernoulli(ccdf+pdf).mask(censored_mask), obs=cens)