Hi all,
I have a model with a custom likelihood function.
Is it possible to use the `Predictive` runtime utility to draw samples for
posterior predictive tasks, e.g. to pass them later to ArviZ?
Here is a simple example. Assume we don’t have a normal-distribution likelihood,
so instead of `numpyro.sample(..., obs=X)` I used `numpyro.factor`.
The inference works, but I am not sure whether it is even possible to use
`Predictive` without writing a `sample` function (i.e. a complete
`Distribution` class).
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import numpyro
import numpyro.distributions as dist
from dataclasses import dataclass
from jax import lax, random
from numpyro.distributions.distribution import Distribution
from numpyro.infer import MCMC, NUTS, Predictive
@dataclass
class args:
    """Settings for the data simulation, MCMC run and JAX backend.

    The values are read as class attributes (``args.seed`` etc.) by the rest
    of the script. The attributes are annotated so that ``@dataclass`` turns
    them into real fields. Without annotations the decorator creates no fields
    at all, so ``args(seed=...)`` would fail.
    """

    seed: int = 12345678  # shared seed for the NumPy and JAX RNGs
    num_data: int = 1000  # number of simulated observations
    mu: float = 1.5  # true mean used to simulate the data
    sigma: float = 1.2  # true standard deviation used to simulate the data
    num_samples: int = 1000  # MCMC draws kept per chain
    num_warmup: int = 1000  # MCMC warmup iterations per chain
    num_chains: int = 4  # number of MCMC chains
    device: str = "cpu"  # platform passed to numpyro.set_platform
# Select the JAX backend (args.device) and request args.num_chains host
# devices, presumably so MCMC can run the chains in parallel.
# NOTE(review): these calls must stay ahead of any JAX computation; the
# host device count is expected to take effect only before JAX initializes.
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)
def dist_numpy(rng, num_data: int = 100, mu: float = 0.0, sigma: float = 1.0):
    """Draw ``num_data`` independent samples from the normal distribution N(mu, sigma).

    Args:
        rng: NumPy random generator exposing ``normal`` (the script passes
            ``np.random.default_rng(...)``).
        num_data: How many samples to draw.
        mu: Mean of the normal distribution.
        sigma: Standard deviation of the normal distribution.

    Returns:
        The ``num_data`` draws produced by ``rng.normal``.
    """
    draw_normal = rng.normal
    return draw_normal(size=num_data, loc=mu, scale=sigma)
def log_prob(loc, scale, value):
    """Element-wise log-density of ``value`` under Normal(loc, scale).

    Computes ``-0.5 * ((value - loc) / scale) ** 2 - log(sqrt(2 * pi) * scale)``
    with ``jax.numpy``, so the arguments broadcast against each other.
    Adapted from the NumPyro ``Normal.log_prob`` source.
    """
    # Distance of each value from the mean, measured in units of the scale.
    standardized = (value - loc) / scale
    # Log of the Gaussian normalising constant sqrt(2*pi) * scale.
    log_norm_const = jnp.log(jnp.sqrt(2 * jnp.pi) * scale)
    return -0.5 * standardized**2 - log_norm_const
def model(X=None):
    """Normal model with unknown location and scale, scored via ``numpyro.factor``.

    ``loc`` gets a standard-normal prior and ``scale`` a standard half-normal
    prior. When ``X`` is given, the element-wise ``log_prob`` of ``X`` is added
    to the joint log-density through the ``"custom_logp"`` factor site. When
    ``X`` is None (e.g. prior sampling with ``Predictive``), only the two prior
    sample sites run.

    NOTE: the likelihood enters only through ``numpyro.factor``, so the model
    has no sample site that generates ``X``. ``Predictive`` therefore has no
    site from which to draw new ``X`` values. Posterior predictive draws need
    an explicit generative ``numpyro.sample`` / ``numpyro.deterministic`` site
    for the data.

    Args:
        X: Observed data, or None to run the priors only.
    """
    loc = numpyro.sample("loc", dist.Normal())
    # The scale must be strictly positive. A Normal prior puts half its mass on
    # negative values, where log_prob takes the log of a negative number (NaN).
    scale = numpyro.sample("scale", dist.HalfNormal())
    if X is not None:
        log_density = log_prob(loc=loc, scale=scale, value=X)
        numpyro.factor("custom_logp", log_density)
# RNG for numpy (used only to simulate the observed data)
rng_numpy = np.random.default_rng(args.seed)
# RNG for jax: rng_trace drives the MCMC run, rng_prior the prior Predictive call.
# NOTE(review): rng_post is never used below; presumably it is reserved for a
# posterior-predictive Predictive call.
rng_trace, rng_prior, rng_post = random.split(random.PRNGKey(args.seed), 3)
# Generate data: args.num_data draws from N(args.mu, args.sigma)
X = dist_numpy(
    rng_numpy, num_data=args.num_data, mu=args.mu, sigma=args.sigma
)
# Get the model priors for loc and scale. X is not passed, so model() skips
# the factor site and only the prior sample sites run.
model_prior = Predictive(model, num_samples=100)(rng_prior)
# Run inference with NUTS; the likelihood of X enters through the factor site.
mcmc = MCMC(
    NUTS(model),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True
)
mcmc.run(rng_trace, X=X)
mcmc.print_summary()
# Print the values used to simulate X, for comparison with the posterior summary.
print(f"Designed values for mu={args.mu}, sigma={args.sigma}")