Posterior Prediction from Custom Likelihood Model

Hi all,

I have a model with custom likelihood function.
Is it possible to use the `Predictive` runtime utility to draw samples for posterior predictive checks, e.g. to pass them to ArviZ later?

Here is a simple example. Assume we don’t have a normal-distribution likelihood, so instead of `numpyro.sample(..., obs=X)` I used `numpyro.factor`.

Inference works, but I am not sure whether it is even possible to use `Predictive` without writing a `sample` method (i.e. a complete `Distribution` subclass).

import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import numpyro
import numpyro.distributions as dist
from dataclasses import dataclass

from jax import lax, random
from numpyro.distributions.distribution import Distribution
from numpyro.infer import MCMC, NUTS, Predictive

@dataclass
class args:
    """Experiment configuration, used as a class-level namespace (``args.seed``).

    The attributes are annotated so that ``@dataclass`` actually registers them
    as fields; without annotations the decorator generates nothing. Class-level
    access is unchanged, and ``args(seed=...)`` can now override individual values.
    """

    seed: int = 12345678  # seed for both the NumPy Generator and the JAX PRNGKey
    num_data: int = 1000  # number of simulated observations
    mu: float = 1.5  # true mean of the simulated data (passed to dist_numpy)
    sigma: float = 1.2  # true std. dev. of the simulated data (passed to dist_numpy)
    num_samples: int = 1000  # MCMC num_samples
    num_warmup: int = 1000  # MCMC num_warmup
    num_chains: int = 4  # MCMC chains; also the requested host device count
    device: str = "cpu"  # platform passed to numpyro.set_platform

# Select the JAX platform (args.device) and ask NumPyro to expose
# args.num_chains host devices, presumably so the MCMC chains below can run
# in parallel. NOTE(review): NumPyro documents these as start-of-program
# settings (before JAX is used) — keep them ahead of any JAX computation.
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

def dist_numpy(rng, num_data: int = 100, mu: float = 0.0, sigma: float = 1.0):
    """Draw ``num_data`` i.i.d. samples from N(mu, sigma) with a NumPy Generator.

    ``sigma`` is the standard deviation (NumPy's ``scale`` argument).
    """
    draws = rng.normal(mu, sigma, num_data)
    return draws

def log_prob(loc, scale, value):
    """Elementwise Normal(loc, scale) log-density evaluated at ``value``.

    Same formula as NumPyro's ``Normal.log_prob`` (copied from the NumPyro
    GitHub source), used here as a stand-in "custom" likelihood.
    """
    z = (value - loc) / scale
    log_norm_const = jnp.log(jnp.sqrt(2 * jnp.pi) * scale)
    return -0.5 * z**2 - log_norm_const


def model(X=None):
    """Normal(loc, scale) model whose likelihood enters only via ``numpyro.factor``.

    Args:
        X: observed data. When None (prior-only mode, e.g. under ``Predictive``),
            only the priors on ``loc`` and ``scale`` are sampled.
    """
    loc = numpyro.sample("loc", dist.Normal())
    # `scale` must be positive: a Normal prior puts half its mass on negative
    # values, where log_prob takes the log of a negative number (NaN).
    scale = numpyro.sample("scale", dist.HalfNormal())
    if X is None:
        return
    # Custom log-likelihood added directly to the joint log density.
    numpyro.factor("custom_logp", log_prob(loc=loc, scale=scale, value=X))
        
# --- Script: simulate data, draw prior samples, fit `model` with NUTS. ---
# RNG for numpy (used only to simulate the observed data)
rng_numpy = np.random.default_rng(args.seed)
# RNG for jax: three keys split from the same seed.
# NOTE(review): rng_post is not used in this snippet.
rng_trace, rng_prior, rng_post = random.split(random.PRNGKey(args.seed), 3)
# Generate data: args.num_data draws from N(args.mu, args.sigma)
X = dist_numpy(
    rng_numpy, num_data=args.num_data, mu=args.mu, sigma=args.sigma
)
# Get the model priors for loc and scale: `model` is called without X, so no
# likelihood factor is added and Predictive returns 100 prior draws.
model_prior = Predictive(model, num_samples=100)(rng_prior)
# Run inference: NUTS on the factor-based (custom) likelihood
mcmc = MCMC(
    NUTS(model),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True
)
mcmc.run(rng_trace, X=X)
mcmc.print_summary()

# Print the true data-generating values for comparison with the summary
print(f"Designed values for mu={args.mu}, sigma={args.sigma}")
1 Like

To draw samples from a custom likelihood, you can probably use NUTS directly with a `potential_fn` (the negative log density), something like:

mcmc = MCMC(NUTS(potential_fn=lambda x: -custom_log_prob(x)), ...)
mcmc.run(...)

`Predictive` only works with `sample` statements, by running the model forward. It does not provide a mechanism to draw samples from an arbitrary density (except via `infer_discrete`).

@fehiepsi

Thanks. Maybe I have not asked clearly what I want. Imagine we have:

  • Custom likelihood as a blackbox function.
  • We don’t have an explicit prediction (noise model) function. For example, in linear regression I don’t know y = ax + b, but I have a function that returns the likelihood given (x_data, y_data, a, b, noise).
  • After the training phase has finished, how can I predict y for a given x_test?

Here is an example of linear regression with a custom likelihood. After the first MCMC run has finished and I have my model parameters, I want to predict y for x_test = 0.1. Based on your suggestion, I wrote something which could be totally wrong:

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from dataclasses import dataclass
from jax.scipy.stats import norm
from jax import  random
from numpyro.distributions.distribution import Distribution
from numpyro.infer import MCMC, NUTS, Predictive

@dataclass
class args:
    """Experiment configuration, used as a class-level namespace (``args.seed``).

    The attributes are annotated so that ``@dataclass`` actually registers them
    as fields; without annotations the decorator generates nothing. Class-level
    access is unchanged, and ``args(seed=...)`` can now override individual values.
    """

    seed: int = 12345678  # seed for both the NumPy Generator and the JAX PRNGKey
    num_data: int = 1000  # number of simulated (x, y) points
    intercept: float = 1.5  # true intercept of the simulated line
    slope: float = -1.2  # true slope of the simulated line
    noise_scale: float = 0.5  # intended noise std. dev., passed to gen_data
    num_samples: int = 1000  # MCMC num_samples
    num_warmup: int = 2000  # MCMC num_warmup
    num_chains: int = 4  # MCMC chains; also the requested host device count
    device: str = "cpu"  # platform passed to numpyro.set_platform

# Select the JAX platform (args.device) and ask NumPyro to expose
# args.num_chains host devices, presumably so the MCMC chains below can run
# in parallel. NOTE(review): NumPyro documents these as start-of-program
# settings (before JAX is used) — keep them ahead of any JAX computation.
numpyro.set_platform(args.device)
numpyro.set_host_device_count(args.num_chains)

def gen_data(rng, num_data: int,
             intercept: float, slope: float, noise_scale: float):
    """Simulate a noisy straight line ``y = intercept + slope * x + eps``.

    ``x`` is ``num_data`` evenly spaced points on [0, 1] and
    ``eps ~ N(0, noise_scale)`` is drawn from the NumPy Generator ``rng``.

    Returns:
        (X, y): two float arrays of length ``num_data``.
    """
    X = np.linspace(0, 1, num_data)
    # y = intercept + slope * x
    regression_line = intercept + slope * X
    # Add noise with the caller's standard deviation; a hard-coded value here
    # would silently ignore `noise_scale`.
    y = regression_line + rng.normal(scale=noise_scale, size=num_data)
    return X, y
def logp(X, y, a, b, sigma):
    """The log-likelihood function for a linear model
    y ~ a*x + b + error, error ~ N(0, sigma).

    Returns the summed Normal(0, sigma) log-density of the residuals
    ``y - (a*X + b)``.
    """
    y_hat = a * X + b  # BUT WE DON'T KNOW IT in my case and function returns L for me.
    # Use logpdf directly instead of log(pdf(...)): the pdf underflows to 0 for
    # large standardized residuals, which would make L = -inf and break NUTS.
    L = jnp.sum(norm.logpdf(y - y_hat, loc=0, scale=sigma))
    return L

def model(X=None, y=None):
    """Linear-regression model whose likelihood enters only via ``numpyro.factor``.

    Priors: slope ``a`` and intercept ``b`` ~ Normal(), noise ``sigma`` ~ HalfNormal()
    (NumPyro default parameters). Without ``X`` only the priors are sampled.
    """
    a = numpyro.sample("a", dist.Normal())
    b = numpyro.sample("b", dist.Normal())
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    if X is None:
        return
    # Black-box log-likelihood added directly to the joint log density.
    numpyro.factor("custom_logp", logp(X=X, y=y, a=a, b=b, sigma=sigma))

# Extra JAX utilities for the posterior-predictive potential below.
import jax
from jax.scipy.special import logsumexp

# RNG for numpy (used only to simulate the data)
rng_numpy = np.random.default_rng(args.seed)
# RNG for jax: key for fitting, spare key, key for the predictive MCMC
rng_trace, rng_prior, rng_post = random.split(random.PRNGKey(args.seed), 3)
# Generate data

X, y = gen_data(rng_numpy, num_data=args.num_data,
               intercept=args.intercept, slope=args.slope,
               noise_scale=args.noise_scale)
designed_label = f"y = {args.slope}x + {args.intercept} + N(0, {args.noise_scale})"

# Run inference
mcmc = MCMC(
    NUTS(model),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True
)
mcmc.run(rng_trace, X=X, y=y)
posterior_samples = mcmc.get_samples()  # flat: one entry per draw (all chains)
mcmc.print_summary()
print(f"designed: {designed_label}")

# Posterior predictive for a new input x_test, using only the black-box logp.
x_test = 0.1
print("\n predictive mcmc")


def predictive_potential(y_new):
    """Negative log posterior-predictive density of ``y_new`` at ``x_test``.

    p(y_new | data) = mean_s p(y_new | x_test, a_s, b_s, sigma_s), i.e. a mixture
    over posterior draws. Summing logp over draws instead would multiply the
    likelihoods (a product of experts), which is far too concentrated.
    """
    per_draw_logp = jax.vmap(logp, in_axes=(None, None, 0, 0, 0))(
        x_test,
        y_new,
        posterior_samples["a"],
        posterior_samples["b"],
        posterior_samples["sigma"],
    )
    num_draws = per_draw_logp.shape[0]
    # log-mean-exp over draws, computed stably with logsumexp
    return -(logsumexp(per_draw_logp) - jnp.log(num_draws))


# Run predictive
mcmc_pred = MCMC(
    NUTS(potential_fn=predictive_potential),
    num_warmup=args.num_warmup,
    num_samples=args.num_samples,
    num_chains=args.num_chains,
    progress_bar=True
)
# Use the key reserved for this step rather than reusing the fitting key.
mcmc_pred.run(rng_post, init_params=jnp.repeat(0., args.num_chains))
mcmc_pred.print_summary()
print(f"This should be around y = {args.slope*x_test+args.intercept} for x_test={x_test}")

Your code looks good to me. A better potential_fn might return

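-(logsumexp(jax.vmap(logp, (None, None, 0, 0, 0))(...)) - jnp.log(num_samples))

i.e. vmap over the posterior samples of a, b and sigma to get a batch of logp values (one per draw). Then compute their log-mean-exp, $\log \frac{1}{S}\sum_{s=1}^{S} \exp(\mathrm{logp}_s)$, and negate it. The constant $\log S$ only shifts the potential, so it does not change what NUTS samples, but with these parentheses the expression is exactly the negative log posterior predictive density.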

1 Like