# Posterior Prediction from Custom Likelihood Model

Hi all,

I have a model with a custom likelihood function.
Is it possible to use the `Predictive` utility to draw samples for posterior predictive tasks, i.e. to pass them later to ArviZ?

Here is a simple example. Assume we don't have a normal-distribution likelihood, and instead of `numpyro.sample(..., obs=X)` I used `numpyro.factor`.

Inference works, but I am not sure whether it is even possible to use `Predictive` without writing a `sample` method (i.e. a complete `Distribution` class).

```python
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from dataclasses import dataclass

from jax import random
from numpyro.infer import MCMC, NUTS, Predictive


@dataclass
class args:
    seed: int = 12345678
    num_data: int = 1000
    mu: float = 1.5
    sigma: float = 1.2
    num_samples: int = 1000
    num_warmup: int = 1000
    num_chains: int = 4
    device: str = "cpu"


numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)


def dist_numpy(rng, num_data: int = 100, mu: float = 0.0, sigma: float = 1.0):
    """Generate data from a normal distribution N(mu, sigma)."""
    return rng.normal(loc=mu, scale=sigma, size=num_data)


def log_prob(loc, scale, value):
    """Normal log-density, copied from the NumPyro GitHub source."""
    normalize_term = jnp.log(jnp.sqrt(2 * jnp.pi) * scale)
    value_scaled = (value - loc) / scale
    return -0.5 * value_scaled**2 - normalize_term


def model(X=None):
    loc = numpyro.sample("loc", dist.Normal())
    # The scale must be positive, so use a half-normal prior.
    scale = numpyro.sample("scale", dist.HalfNormal())
    if X is not None:
        log_density = log_prob(loc=loc, scale=scale, value=X)
        numpyro.factor("custom_logp", log_density)


# RNG for numpy
rng_numpy = np.random.default_rng(args.seed)
# RNGs for jax
rng_trace, rng_prior, rng_post = random.split(random.PRNGKey(args.seed), 3)
# Generate data
X = dist_numpy(rng_numpy, num_data=args.num_data, mu=args.mu, sigma=args.sigma)
# Draw from the model priors for loc and scale
model_prior = Predictive(model, num_samples=100)(rng_prior)
# Run inference
mcmc = MCMC(
    NUTS(model),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True,
)
mcmc.run(rng_trace, X=X)
mcmc.print_summary()

print(f"Designed values: mu={args.mu}, sigma={args.sigma}")
```
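To make clear what I am trying to avoid: as far as I understand, the "complete" alternative would be a small `Distribution` subclass with a `sample` method, roughly like this untested sketch (shape handling only covers the scalar-parameter case):

```python
from numpyro.distributions.distribution import Distribution

class CustomNormal(Distribution):
    """Minimal wrapper: log_prob for inference, sample for Predictive."""
    support = dist.constraints.real

    def __init__(self, loc=0.0, scale=1.0, validate_args=None):
        self.loc, self.scale = loc, scale
        super().__init__(batch_shape=jnp.shape(loc), validate_args=validate_args)

    def sample(self, key, sample_shape=()):
        # Reparameterized normal draw, so Predictive can run the model forward.
        return self.loc + self.scale * random.normal(
            key, shape=sample_shape + self.batch_shape
        )

    def log_prob(self, value):
        # Reuse the hand-written normal log-density from above.
        return log_prob(loc=self.loc, scale=self.scale, value=value)
```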

To draw samples from a custom likelihood, I guess you can use NUTS? Something like:

```python
mcmc = MCMC(NUTS(potential_fn=lambda x: -custom_log_prob(x)), ...)
mcmc.run(...)
```

`Predictive` only works with sample statements, by running the model forward. It does not provide a mechanism to draw samples from an arbitrary density (except for `infer_discrete`).
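For comparison, if the likelihood can be written as a sample statement, the usual forward route works. A minimal sketch, reusing the normal likelihood from your example (`model_with_obs` and `posterior_predictive` are just illustrative names):

```python
def model_with_obs(X=None):
    loc = numpyro.sample("loc", dist.Normal())
    scale = numpyro.sample("scale", dist.HalfNormal())
    # An observed sample site; when called with X=None, Predictive draws new data here.
    return numpyro.sample("X", dist.Normal(loc, scale), obs=X)

# Replay the model forward, conditioned on the posterior samples.
predictive = Predictive(model_with_obs, posterior_samples=mcmc.get_samples())
posterior_predictive = predictive(rng_post)  # dict with key "X"
```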

@fehiepsi

Thanks. Maybe I did not state clearly what I want. Imagine we have:

- A custom likelihood given as a black-box function.
- No explicit prediction (noise model) function. For example, I don't know `y = ax + b` in linear regression, but I have a function which spits out the likelihood given `(x_data, y_data, a, b, noise)`.
- After the training phase finishes, how can I find the prediction for `y` at a given `x_test`? (The target quantity is written out below.)
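
In other words, what I am after is the posterior predictive density, approximated by averaging the black-box likelihood over the posterior draws:

$$
p(y^\ast \mid x^\ast, \mathcal{D}) = \int p(y^\ast \mid x^\ast, \theta)\, p(\theta \mid \mathcal{D})\, d\theta \approx \frac{1}{N} \sum_{i=1}^{N} p(y^\ast \mid x^\ast, \theta_i), \qquad \theta_i \sim p(\theta \mid \mathcal{D}).
$$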

Here is an example of linear regression with a custom likelihood. After the first `mcmc` run finishes and I have my model parameters, I want to predict `y` for `x_test = 0.1`. Based on your suggestion, I wrote something which could be totally wrong:

```python
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from dataclasses import dataclass
from jax.scipy.stats import norm
from jax import random
from numpyro.infer import MCMC, NUTS


@dataclass
class args:
    seed: int = 12345678
    num_data: int = 1000
    intercept: float = 1.5
    slope: float = -1.2
    noise_scale: float = 0.5
    num_samples: int = 1000
    num_warmup: int = 2000
    num_chains: int = 4
    device: str = "cpu"


numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)


def gen_data(rng, num_data: int, intercept: float, slope: float, noise_scale: float):
    X = np.linspace(0, 1, num_data)
    # y = a + b*x
    regression_line = intercept + slope * X
    y = regression_line + rng.normal(scale=noise_scale, size=num_data)
    return X, y


def logp(X, y, a, b, sigma):
    """The likelihood function for a linear model: y ~ a*x + b + error."""
    # In my real problem I do NOT know this form; the function is a
    # black box that returns L for me.
    y_hat = a * X + b
    L = jnp.sum(norm.logpdf(y - y_hat, loc=0, scale=sigma))
    return L


def model(X=None, y=None):
    a = numpyro.sample("a", dist.Normal())
    b = numpyro.sample("b", dist.Normal())
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    if X is not None:
        log_density = logp(X=X, y=y, a=a, b=b, sigma=sigma)
        numpyro.factor("custom_logp", log_density)


# RNG for numpy
rng_numpy = np.random.default_rng(args.seed)
# RNGs for jax
rng_trace, rng_prior, rng_post = random.split(random.PRNGKey(args.seed), 3)
# Generate data
X, y = gen_data(
    rng_numpy,
    num_data=args.num_data,
    intercept=args.intercept,
    slope=args.slope,
    noise_scale=args.noise_scale,
)
designed_label = f"y = {args.slope}x + {args.intercept} + N(0, {args.noise_scale})"

# Run inference
mcmc = MCMC(
    NUTS(model),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True,
)
mcmc.run(rng_trace, X=X, y=y)
posterior_samples = mcmc.get_samples()
mcmc.print_summary()
print(f"designed: {designed_label}")

## MY QUESTION IS HERE
x_test = 0.1
print("\n predictive mcmc")
# Run predictive
mcmc_pred = MCMC(
    NUTS(
        potential_fn=lambda y: -logp(
            X=x_test,
            y=y,
            a=posterior_samples["a"],
            b=posterior_samples["b"],
            sigma=posterior_samples["sigma"],
        )
    ),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True,
)
mcmc_pred.run(rng_post, init_params=jnp.repeat(0.0, args.num_chains))
mcmc_pred.print_summary()
print(f"This should be around y = {args.slope * x_test + args.intercept} for x_test={x_test}")
```

Your code looks good to me. A better `potential_fn` might return

```python
-(logsumexp(jax.vmap(logp, (None, None, 0, 0, 0))(...)) - jnp.log(num_samples))
```

i.e. vmap over the posterior samples `a`, `b`, `sigma` to get a batch of `logp` values, then compute their log-mean-exp.
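
Concretely, a sketch of that potential function might look like this (the names `post` and `predictive_potential` are illustrative; `logp`, `x_test`, and `args` come from your code above):

```python
import jax
from jax.scipy.special import logsumexp

post = mcmc.get_samples()  # each array has shape (num_chains * num_samples,)
n = post["a"].shape[0]

def predictive_potential(y):
    # One log-likelihood per posterior draw (a_i, b_i, sigma_i).
    batched_logp = jax.vmap(logp, in_axes=(None, None, 0, 0, 0))(
        x_test, y, post["a"], post["b"], post["sigma"]
    )
    # Negative log-mean-exp: -log((1/n) * sum_i exp(logp_i)).
    return -(logsumexp(batched_logp) - jnp.log(n))

mcmc_pred = MCMC(
    NUTS(potential_fn=predictive_potential),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
)
mcmc_pred.run(rng_post, init_params=jnp.repeat(0.0, args.num_chains))
mcmc_pred.print_summary()
```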
