I know this must be a stupid bug, but I cannot find what triggers this issue. I have a Planck black body radiation model that should be relatively straightforward to sample.
```python
import jax
import jax.numpy as jnp


@jax.jit
def bb_flux(λnm: "nanometers", T: "K", amp: float = 1.) -> jnp.ndarray:
    """ default units blackbody as a flux distribution as a function of
    wavelength, temperature and amplitude.

    :param λnm: wavelength in nm
    :param T: temperature in Kelvins
    :param amp: dimensionless normalization factor
    :return: evaluation of the blackbody radiation in SI units (W / m^2 / m / sr)
    """
    # Constants
    h = 6.626e-34  # Planck's constant [J*s = m**2 * kg / s]
    c = 2.998e8    # Speed of light [m/s]
    k = 1.381e-23  # Boltzmann constant [J/K]
    # Planck's law (wavelength converted from nm to m)
    λ = λnm * 1e-9
    return amp * 2 * h * c ** 2 / (λ ** 5 * (jnp.exp(h * c / (λ * k * T)) - 1))
```
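For reference, `bb_flux` evaluates Planck's law for spectral radiance per unit wavelength, multiplied by the amplitude `amp` (written $A$ here), with $\lambda = 10^{-9}\,\lambda_{\mathrm{nm}}$ in metres:

$$
B_\lambda(\lambda, T) = A \, \frac{2 h c^2}{\lambda^5} \, \frac{1}{\exp\!\left(\dfrac{h c}{\lambda k T}\right) - 1}
$$

With SI constants the result is in W m⁻² m⁻¹ sr⁻¹. Over a 100 to 900 nm grid, $\lambda^5$ ranges from $10^{-35}$ to about $5.9 \times 10^{-31}\ \mathrm{m}^5$.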
The following is one of my attempts at writing the model:
```python
import numpyro
import numpyro.distributions as dist


def model(λnm: jnp.ndarray, fobs: jnp.ndarray = None):
    """ Model of fobs | fhat(λnm, T, amp) """
    T = numpyro.sample('T', dist.Uniform(3000., 10000.))
    log_amp = numpyro.sample('log_amp', dist.Uniform(-0.01, 0.01))
    amp = jnp.power(10., log_amp)
    fhat = numpyro.deterministic('fhat', bb_flux(λnm, T, amp))
    residuals = jnp.nan_to_num(fobs - fhat)
    numpyro.factor('obs', dist.Normal(0, 1e4).log_prob(residuals).sum())
```
In this model, the `factor` statement could be replaced by `numpyro.sample('obs', dist.Normal(fhat, 1e4), obs=fobs)` (the issue remains).
The inference does not run because of an initialization error:

```python
import numpy as np
from numpyro import infer

λnm = np.linspace(100, 900, 1000)
teff_true = 4500  # K
amp_true = 1.
# bb_flux already applies amp, so it is not multiplied in a second time
fobs = bb_flux(λnm, teff_true, amp_true)

kernel = infer.NUTS(model, init_strategy=infer.init_to_median())
sampler = infer.MCMC(
    kernel,
    num_chains=1,
    num_warmup=1000,
    num_samples=2000,
    progress_bar=True)
sampler.run(jax.random.PRNGKey(7), λnm, fobs)
```

```
RuntimeError: Cannot find valid initial parameters. Please check your model again.
```
I tried changing the `init_strategy` keyword, but it did not make any difference.
Does anyone see the obvious mistake? I suspect it comes from the dynamic range of the `bb_flux` outputs, but there must be a way around it, right?
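A quick diagnostic sketch for that suspicion: NumPyro's initializer only accepts a starting point where both the log density and its gradient are finite, so the gradient of `bb_flux` with respect to `T` is worth checking on its own. The snippet copies `bb_flux` so it runs standalone; the test temperature of 4500 K and summing the flux into a scalar with `jnp.sum` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp


# Copy of bb_flux from the question (SI constants, wavelength given in nm).
@jax.jit
def bb_flux(λnm, T, amp=1.0):
    h = 6.626e-34  # Planck's constant [J s]
    c = 2.998e8    # speed of light [m/s]
    k = 1.381e-23  # Boltzmann constant [J/K]
    λ = λnm * 1e-9
    return amp * 2 * h * c ** 2 / (λ ** 5 * (jnp.exp(h * c / (λ * k * T)) - 1))


# float32 by default, i.e. unless jax_enable_x64 is switched on.
λnm = jnp.linspace(100.0, 900.0, 1000)

# Forward pass at a plausible temperature: large but finite values.
flux = bb_flux(λnm, 4500.0)
print("flux finite:", bool(jnp.all(jnp.isfinite(flux))), "max:", float(jnp.max(flux)))

# Derivative with respect to T, which NUTS also needs at its initial point.
# Differentiating a / y gives -a / y**2. At T = 4500 K on this grid,
# y = λ**5 * (exp(...) - 1) stays below about 1e-21, so 1 / y**2 is
# larger than the float32 maximum (~3.4e38).
dflux_dT = jax.grad(lambda T: jnp.sum(bb_flux(λnm, T)))(4500.0)
print("gradient wrt T:", float(dflux_dT))
```

If the flux is finite but the gradient comes out as `inf` or `nan`, that would point to float32 range limits rather than to the model structure. A cheap comparison is to call `jax.config.update("jax_enable_x64", True)` (or `numpyro.enable_x64()`) at the very top of the script, before any arrays are created, and rerun.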