Dear All,
I know this must be a stupid bug, but I cannot find what triggers this issue. I have a Planck black body radiation model that should be relatively straightforward to sample.
```python
import jax
import jax.numpy as jnp


@jax.jit
def bb_flux(λnm: jax.Array, T: float, amp: float = 1.0) -> jax.Array:
    """
    Blackbody spectral radiance (Planck's law) as a function of wavelength,
    temperature and amplitude.

    :param λnm: wavelength in nm
    :param T: temperature in K
    :param amp: dimensionless normalization factor
    :return: spectral radiance B_λ in SI units (W / m^2 / sr / m)
    """
    # Constants
    h = 6.626e-34  # Planck's constant [J*s = m**2 * kg / s]
    c = 2.998e8    # Speed of light [m/s]
    k = 1.381e-23  # Boltzmann constant [J/K]
    # Planck's law
    λ = λnm * 1e-9  # nm -> m
    return amp * 2 * h * c ** 2 / (λ ** 5 * (jnp.exp(h * c / (λ * k * T)) - 1))
```
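For a sense of the numbers involved: JAX computes in float32 unless double precision is enabled, and some intermediates of `bb_flux` sit near or beyond the float32 range. The sketch below evaluates them in plain NumPy float64 (so the magnitudes themselves are exact) at the corners of the wavelength grid and of the prior on `T`. The `1/(λkT)**2` line assumes that reverse-mode autodiff differentiates the quotient `hc/y` through the factor `-hc/y**2`; the variable names are illustrative only.

```python
import numpy as np

# Constants copied from bb_flux (SI units)
h, c, k = 6.626e-34, 2.998e8, 1.381e-23

f32 = np.finfo(np.float32)  # limits of JAX's default dtype (x64 disabled)

lam = np.array([100e-9, 900e-9])  # ends of the wavelength grid [m]
T = np.array([3000.0, 10000.0])   # ends of the Uniform prior on T [K]

lam5 = lam ** 5            # λ**5 in the denominator [m^5]
y = np.outer(lam, k * T)   # λ*k*T for each (λ, T) corner [J*m]
inv_y2 = 1.0 / y ** 2      # factor assumed to appear in d(hc/y)/dy = -hc/y**2

print("λ**5:        ", lam5.min(), "to", lam5.max())
print("λkT:         ", y.min(), "to", y.max())
print("1/(λkT)**2:  ", inv_y2.min(), "to", inv_y2.max())
print("float32 tiny:", f32.tiny, " float32 max:", f32.max)
```

`λ**5` still fits in float32, but `1/(λkT)**2` exceeds the largest float32 value at every corner, so any gradient computed through that factor would overflow to infinity.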
Here is one of my attempts at writing the model:
```python
import numpyro
import numpyro.distributions as dist


def model(λnm: jax.Array, fobs: jax.Array = None):
    """Model of fobs | fhat(λ; T, amp)"""
    T = numpyro.sample('T', dist.Uniform(3000., 10000.))
    log_amp = numpyro.sample('log_amp', dist.Uniform(-0.01, 0.01))
    amp = jnp.power(10., log_amp)
    fhat = numpyro.deterministic('fhat', bb_flux(λnm, T, amp))
    residuals = jnp.nan_to_num(fobs - fhat)
    numpyro.factor('obs', dist.Normal(0, 1e4).log_prob(residuals).sum())
```
In this model, the `factor` statement could be replaced by `numpyro.sample('obs', dist.Normal(fhat, 1e4), obs=fobs)`, but the issue remains.
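That the issue survives this swap is expected: for a Normal likelihood, log N(fobs | fhat, σ) = log N(fobs - fhat | 0, σ), so both versions add the same term to the log joint (apart from `jnp.nan_to_num`, which only matters once NaNs or infinities appear). A quick check, using `jax.scipy.stats.norm.logpdf` as a stand-in for `dist.Normal(...).log_prob` and made-up toy numbers:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Toy values, purely illustrative
fobs = jnp.array([1.0, 2.5, 4.0])
fhat = jnp.array([1.2, 2.0, 4.5])
sigma = 0.5

# Equivalent of numpyro.factor('obs', Normal(0, sigma).log_prob(fobs - fhat).sum())
lp_factor = norm.logpdf(fobs - fhat, loc=0.0, scale=sigma).sum()
# Equivalent of numpyro.sample('obs', Normal(fhat, sigma), obs=fobs)
lp_sample = norm.logpdf(fobs, loc=fhat, scale=sigma).sum()

print(lp_factor, lp_sample)
```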
The inference does not run because of an initialization error:

```python
import numpy as np
from numpyro import infer

λnm = np.linspace(100, 900, 1000)
teff_true = 4500  # K
amp_true = 1.
fobs = bb_flux(λnm, teff_true, amp_true)

kernel = infer.NUTS(model, init_strategy=infer.init_to_median())
sampler = infer.MCMC(
    kernel,
    num_chains=1, num_warmup=1000, num_samples=2000,
    progress_bar=True)
sampler.run(jax.random.PRNGKey(7), λnm, fobs)
```

```
RuntimeError: Cannot find valid initial parameters. Please check your model again.
```
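As far as I know, NUTS needs both the log density and its gradient to be finite at the starting point, so one way to narrow this down is to evaluate both outside NumPyro. The sketch below re-creates the Gaussian likelihood of the `factor` statement with `jax.scipy.stats.norm.logpdf`, fixes `amp = 1`, and differentiates with respect to `T` at 6500 K, near the prior median that `init_to_median` aims for. The evaluation point and the helper name `log_lik` are illustrative choices, and `jit` is dropped from `bb_flux` for brevity.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.stats import norm


def bb_flux(λnm, T, amp=1.0):
    # Same Planck formula as in the post, SI constants
    h, c, k = 6.626e-34, 2.998e8, 1.381e-23
    λ = λnm * 1e-9
    return amp * 2 * h * c ** 2 / (λ ** 5 * (jnp.exp(h * c / (λ * k * T)) - 1))


λnm = np.linspace(100, 900, 1000)
fobs = bb_flux(λnm, 4500.0)


def log_lik(T):
    # Gaussian likelihood with sigma = 1e4, as in the model's factor statement
    return norm.logpdf(fobs - bb_flux(λnm, T), loc=0.0, scale=1e4).sum()


value, grad = jax.value_and_grad(log_lik)(6500.0)
print("log-likelihood:", value, "finite:", bool(np.isfinite(value)))
print("d/dT:", grad, "finite:", bool(np.isfinite(grad)))
```

If the log-likelihood is finite but its gradient is not, the starting point is rejected whichever `init_strategy` is used, which would match the behaviour described in this post.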
I tried changing the `init_strategy` keyword, but it made no difference.
Does anyone see the obvious mistake? I suspect it comes from the dynamic range of the `bb_flux` outputs, but there must be something I can do about it, right?
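If a non-finite gradient is indeed the culprit, one option is an algebraically equivalent `bb_flux` whose intermediates stay close to 1: the exponent hc/(λkT) is computed as c₂/(λT) with c₂ = hc/k expressed in nm·K (about 1.4388e7 nm·K), λ⁵ is taken in micrometres, and 1/(eˣ - 1) is rewritten as e⁻ˣ/(1 - e⁻ˣ) via `jnp.expm1`. This is a sketch under those assumptions; `bb_flux_stable`, `C2_NM_K` and `C1_UM` are hypothetical names, and the output keeps the SI units of the original formula.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Same constants as bb_flux in the post (SI units)
h, c, k = 6.626e-34, 2.998e8, 1.381e-23
# Second radiation constant hc/k, converted from m*K to nm*K (~1.4388e7 nm*K)
C2_NM_K = h * c / k * 1e9
# 2hc^2 rescaled so that λ can be given in micrometres: 2hc^2 / (λum * 1e-6)**5
C1_UM = 2 * h * c ** 2 * 1e30


@jax.jit
def bb_flux_stable(λnm, T, amp=1.0):
    """Same quantity as bb_flux (W / m^2 / sr / m), with well-scaled intermediates."""
    x = C2_NM_K / (λnm * T)  # dimensionless hc / (λkT), computed in nm*K
    λum = λnm * 1e-3         # wavelength in micrometres, so λum**5 lies in [1e-5, 1]
    # 1 / (e^x - 1) rewritten as e^-x / (1 - e^-x), which stays bounded for large x
    planck = jnp.exp(-x) / -jnp.expm1(-x)
    return amp * C1_UM / λum ** 5 * planck
```

In the model, `bb_flux_stable` would simply replace `bb_flux`. Alternatively, calling `jax.config.update("jax_enable_x64", True)` (or `numpyro.enable_x64()`) before any arrays are created raises the overflow limit from about 3.4e38 to about 1.8e308, at some speed cost. Separately, with `fobs` peaking around 1e13 and a noise scale of 1e4, the likelihood is extremely narrow, so rescaling `fobs` and σ to order unity might also help step-size adaptation; that part is untested.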