Prior Predictive Density

Hello,

The tutorial "Bayesian Regression Using NumPyro" in the NumPyro documentation shows how to compute the log posterior predictive density:

def log_likelihood(rng_key, params, model, *args, **kwargs):
    model = handlers.condition(model, params)
    model_trace = handlers.trace(model).get_trace(*args, **kwargs)
    obs_node = model_trace["obs"]
    return obs_node["fn"].log_prob(obs_node["value"])


def log_pred_density(rng_key, params, model, *args, **kwargs):
    n = list(params.values())[0].shape[0]
    log_lk_fn = vmap(
        lambda rng_key, params: log_likelihood(rng_key, params, model, *args, **kwargs)
    )
    log_lk_vals = log_lk_fn(random.split(rng_key, n), params)
    return (logsumexp(log_lk_vals, 0) - jnp.log(n)).sum()

rng_key, rng_key_ = random.split(rng_key)
print(
    "Log posterior predictive density: {}".format(
        log_pred_density(
            rng_key_,
            samples_1,
            model,
            marriage=dset.MarriageScaled.values,
            divorce=dset.DivorceScaled.values,
        )
    )
)

How would one find the prior predictive density, i.e. \log \prod_{i=1}^{n} \int p(y_i \mid \theta)\, p(\theta)\, d\theta, where p(\theta) denotes the prior?
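My reasoning: writing S for the number of draws, the tutorial code above computes, with posterior draws \theta^{(s)}, the Monte Carlo estimate

\sum_{i=1}^{n} \log \frac{1}{S} \sum_{s=1}^{S} p(y_i \mid \theta^{(s)}) \;=\; \sum_{i=1}^{n} \Big[ \operatorname{logsumexp}_{s}\, \log p(y_i \mid \theta^{(s)}) - \log S \Big],

so the same estimator with \theta^{(s)} drawn from the prior p(\theta) instead of the posterior should converge to the prior predictive density above.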

Possible solution that I want to check:

Suppose I have a model; below I take an AR(1) for concreteness:

def ar1(x, y=None):
    # x = first lag of y

    # innovation variance; its square root is used as the scale below
    sigma2_y = numpyro.sample("sigma2_y", dist.HalfNormal(scale=2))
    # AR(1) coefficient, truncated to the stationary region (-1, 1)
    phi = numpyro.sample("phi", dist.TruncatedNormal(0, 1, low=-1, high=1))

    numpyro.sample("obs", dist.Normal(loc=phi * x, scale=sigma2_y**0.5), obs=y)

I run the sampler:

rng_key = random.PRNGKey(0)
kernel = NUTS(ar1, max_tree_depth=20, target_accept_prob=0.95)
mcmc = MCMC(kernel, num_chains=3, num_warmup=2000, num_samples=2000)
mcmc.run(rng_key, y=y, x=x)
mcmc.print_summary()
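With 3 chains and 2000 post-warmup draws each, mcmc.get_samples() concatenates the chains, so there are 3 × 2000 = 6000 parameter draws in total; that is where the 6000 below comes from. A quick shape check (using "phi" as an example site):

mcmc.get_samples()["phi"].shape  # (6000,)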

Following the Bayesian regression example, I can calculate the log posterior predictive density as follows:

log_posterior_predictive_density = numpyro.infer.log_likelihood(model=ar1,
                             posterior_samples=mcmc.get_samples(),
                             y=y, x=x)

# 6000 = 3 chains * 2000 draws per chain
(logsumexp(log_posterior_predictive_density['obs'], 0) - jnp.log(6000)).sum()
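A small helper along these lines (just a sketch, assuming the draws are stacked along axis 0 of the pointwise log-likelihood array) avoids hard-coding the number of draws:

def log_pointwise_pred_density(log_lik):
    # log_lik: array of shape (num_draws, num_obs) with pointwise log-likelihoods
    num_draws = log_lik.shape[0]
    return (logsumexp(log_lik, axis=0) - jnp.log(num_draws)).sum()

log_pointwise_pred_density(log_posterior_predictive_density['obs'])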

Can I get the prior predictive density by sampling from the prior predictive, dropping the "obs" samples, and plugging the remaining parameter samples into the posterior_samples argument, as follows?

from numpyro.infer import Predictive

rng_key = random.PRNGKey(1)
rng_key, rng_key_ = random.split(rng_key)
prior_predictive = Predictive(ar1, num_samples=500)
prior_predictions = prior_predictive(rng_key_, x=x)

# keep only the parameter draws; the prior draws of "obs" are not needed
del prior_predictions['obs']

log_prior_predictive_density = numpyro.infer.log_likelihood(model=ar1,
                             posterior_samples=prior_predictions,
                             y=y, x=x)

(logsumexp(log_prior_predictive_density['obs'], 0) - jnp.log(500)).sum()
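If this approach is right, a small wrapper along these lines (just a sketch, reusing the pieces above and assuming the observed site is named "obs") would make it reusable:

def log_prior_pred_density(rng_key, model, num_draws, y, **kwargs):
    # draw from the prior; Predictive without posterior samples samples every
    # site, so the prior draws of "obs" are dropped before scoring
    prior_draws = Predictive(model, num_samples=num_draws)(rng_key, **kwargs)
    prior_draws.pop("obs", None)
    log_lik = numpyro.infer.log_likelihood(model, prior_draws, y=y, **kwargs)
    return (logsumexp(log_lik["obs"], axis=0) - jnp.log(num_draws)).sum()

log_prior_pred_density(random.PRNGKey(2), ar1, 500, y=y, x=x)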