I am working on a model that treats a 2D signal as a sum of 1D signals smoothed out in the 2nd dimension:
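Roughly (writing out what the reproducing code below implements, suppressing the softplus transforms and the small `min_sigma` floor):

$$\mathrm{rate}(x_1, x_2) = \sum_{k=1}^{K} \mathrm{amp}_k(x_1)\,\exp\!\left(-\frac{\bigl(f_k(x_1) - x_2\bigr)^2}{2\,\sigma_k(x_1)^2}\right)$$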
I model the amplitude, width, and center with GPs, using the HSGP approximation. My problem is that everything works when the GP hyperparameters are fixed, but as soon as I let them vary, it all fails. Here is some reproducing code:
```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)
from numpyro.infer import SVI, Trace_ELBO, autoguide, Predictive
from numpyro.optim import Adam

rng_key = random.PRNGKey(seed=42)

# Synthetic data on a 50x50 grid: a Gaussian ridge in the second
# coordinate, constant along the first.
x_train_centered = jnp.array(
    [[_a, _b] for _a in jnp.linspace(-1/2, 1/2, 50) for _b in jnp.linspace(-1/2, 1/2, 50)]
)
y = jnp.asarray(500 * jnp.exp(-0.5 * x_train_centered[:, 1]**2 / 0.1**2)).astype(int)

plt.scatter(*x_train_centered.T, c=y, marker='s')
plt.colorbar()

non_centered = True
ell = 5
y_train_obs = y
monitors = 1.


def model(x, y=None, monitors=1., max_k=1, basis_size=300, min_sigma=0.05, make_error=True):
    length = jnp.array([0.2, 0.1, 0.3])
    if make_error:
        # log_alpha = numpyro.sample('log_alpha', dist.Normal(2*jnp.log(jnp.array([100., 1., 2.])), 1e-2))
        # alpha = jnp.exp(log_alpha)
        # The only sampled hyperparameter: a narrow TruncatedNormal around
        # the previously fixed values.
        alpha = numpyro.sample('alpha', dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6))**2
    else:
        alpha = jnp.array([100., 1., 2.])**2
    offset_guess = jnp.array([500, 0.5, 0.01])
    offsets = numpyro.sample('offsets', dist.Normal(offset_guess, offset_guess/2))
    # length = numpyro.sample('length', dist.TruncatedNormal(length, length/2, low=0.1, high=1.))
    # alpha = numpyro.sample('alpha', dist.InverseGamma(jnp.array([10, 20., 10]), rate=jnp.array([1_000, 10., 10])))

    # HSGP basis over the first coordinate, shared by the three GPs
    # (amplitude, center, width).
    phi = eigenfunctions(x[:, 0], ell, basis_size).T
    lam = diag_spectral_density_squared_exponential(alpha=alpha, length=length, ell=ell, m=basis_size)
    beta = numpyro.sample('beta', dist.Normal().expand([3, max_k, basis_size]))
    # Project basis coefficients to function values, shape (3, max_k, n_points).
    amp, f, sigma = jnp.einsum('ij,ip,jki->jkp', lam**0.5, phi, beta) + offsets[:, None, None]
    amp = jax.nn.softplus(amp)
    sigma = jax.nn.softplus(sigma)
    rate = (amp * jnp.exp(-0.5 * (f - x[:, 1])**2 / (sigma**2 + min_sigma**2))).sum(0) + 1e-4
    numpyro.sample('likelihood', dist.Poisson(rate * monitors), obs=y)
    numpyro.deterministic('_alpha', alpha)
    numpyro.deterministic('_length', length)
    numpyro.deterministic('_beta', beta)
    numpyro.deterministic('_amp', amp)
    numpyro.deterministic('_sigma', sigma)
    numpyro.deterministic('_f', f)
    numpyro.deterministic('_rate', rate)
from numpyro.infer import (
    init_to_feasible, init_to_mean, init_to_median,
    init_to_uniform, init_to_sample, init_to_value,
)

max_k = 1
for init in [init_to_feasible, init_to_mean, init_to_median, init_to_uniform, init_to_sample]:
    try:
        optimizer = Adam(1e-3)
        guide = autoguide.AutoDiagonalNormal(model, init_loc_fn=init)
        # guide = autoguide.AutoNormal(model, init_loc_fn=init_to_value(values={"alpha": jnp.array([100., 1., 2.])}))
        svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
        svi_state = svi.init(
            rng_key=jax.random.PRNGKey(0), x=x_train_centered, y=y_train_obs,
            monitors=monitors, max_k=max_k,
        )

        @jax.jit
        def svi_update(svi_state):
            return svi.update(svi_state=svi_state, x=x_train_centered,
                              y=y_train_obs, monitors=monitors, max_k=max_k)

        # Trigger compilation once before the training loop.
        svi_update(svi_state)

        num_steps = 50_000
        try:
            for i in range(num_steps):
                svi_state, loss = svi_update(svi_state)
                if i % 1_000 == 0:
                    print(f"Step {i}: loss = {loss:.3f}")
                if jnp.isnan(loss):
                    raise Exception('nan loss')
        except KeyboardInterrupt:
            pass

        params = svi.get_params(svi_state)
        predictive = Predictive(model, guide=guide, params=params, num_samples=5)
        rng_key, rng_subkey = jax.random.split(rng_key)
        posterior_samples = predictive(
            rng_subkey, x_train_centered, y_train_obs,
            monitors=monitors, max_k=max_k,
        )
    except Exception as e:
        print(f'init {init} failed')
        # print(e)
        continue
```
Everything that needs to be non-negative has a softplus applied. The only change relative to the working model (`make_error=False`) is replacing a fixed hyperparameter value with a narrow TruncatedNormal centered on that same value. I don't understand how one setup can fail while the other doesn't. The error message is

```
Cannot find valid initial parameters. Please check your model again.
```

along with

```
UserWarning: Out-of-support values provided to log prob method. The value argument should be within the support.
  alpha = numpyro.sample('alpha', dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6))**2
```
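For what it's worth, the prior on its own looks fine. Here is a minimal check (not part of the model above, just the `alpha` site sampled and scored in isolation):

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

# Sample the alpha prior by itself: values land near [100., 1., 2.],
# all above low=1e-6, and the log-probabilities of its own samples are finite.
d = dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6)
samples = d.sample(jax.random.PRNGKey(0), (5,))
print(samples)
print(d.log_prob(samples))
```

So presumably the out-of-support values arise during the autoguide's initialization rather than from the prior itself.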