log_density gives finite values outside the expected support

Hi everyone, I'm having trouble understanding how numpyro.infer.util.log_density works. I am working with models that have finite support because of a Uniform distribution, but log_density yields finite values outside this support. See, e.g., the following code:

```python
import numpy as np
import numpyro
import numpyro.distributions as dist

rng = np.random.default_rng(42)
observed_data = rng.normal(loc=3, scale=1, size=100)

def model(data):
    # Prior
    mu = numpyro.sample("mu", dist.Uniform(1, 5))
    sigma = numpyro.sample("sigma", dist.HalfNormal(10))

    # Likelihood
    with numpyro.plate("data", size=len(data)):
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

print(numpyro.infer.util.log_density(model, (observed_data, ), dict(), {"mu":0.5, "sigma":1})[0])
```

This returns roughly -425, even though mu = 0.5 lies outside the Uniform(1, 5) support.

Am I missing something? Does log_density return the (unnormalized) log posterior density for a given set of parameters in the constrained space?

TY very much for your help

You can call numpyro.enable_validation() to get the desired behavior. By default, log_density assumes that the values you pass are valid (i.e. inside each distribution's support) and does not check them.


TY! Works as expected now. Contrary to what the documentation states, this also works under JIT, which feels weird. Is that expected?

```python
import numpy as np
import numpyro
import jax
import numpyro.distributions as dist

numpyro.enable_validation()
rng = np.random.default_rng(42)
observed_data = rng.normal(loc=3, scale=1, size=100)

def model(data):
    # Prior
    mu = numpyro.sample("mu", dist.Uniform(1, 5))
    sigma = numpyro.sample("sigma", dist.HalfNormal(10))

    # Likelihood
    with numpyro.plate("data", size=len(data)):
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=data)

print(jax.jit(lambda pars : numpyro.infer.util.log_density(model, (observed_data, ), dict(), pars)[0])({"mu":0.5, "sigma":1}))
```

Yes, it should work under jit, though the corresponding warnings (for out-of-support values) will not appear under jit.
