Also, in Hilbert Space Approximation Gaussian Process Module — NumPyro documentation we draw the amplitude from a distribution (and Example: Hilbert space approximation for Gaussian processes (multidimensional) — NumPyro documentation does the same in two dimensions).
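For reference, the pattern in those docs looks roughly like this (a minimal sketch, assuming `hsgp_squared_exponential` from `numpyro.contrib.hsgp.approximation`; the priors here are made up for illustration):

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential

def hsgp_sketch(x, y=None):
    # amplitude (and length scale) drawn from a prior rather than fixed
    alpha = numpyro.sample("alpha", dist.HalfNormal(1.0))
    length = numpyro.sample("length", dist.HalfNormal(1.0))
    f = hsgp_squared_exponential(x, alpha=alpha, length=length, ell=1.5, m=20)
    numpyro.sample("obs", dist.Normal(f, 0.5), obs=y)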
Maybe you can share a very minimal code example that reproduces the error (without plotting commands, just down to the init line that generates it)?
This works for me:
import jax
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)
from numpyro.infer import SVI, Trace_ELBO, Predictive, autoguide, init_to_mean
from numpyro.optim import Adam
rng_key = random.PRNGKey(seed=42)
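# synthetic data: 50x50 grid on [-1/2, 1/2]^2 with integer counts from a Gaussian bump in the second coordinate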
x_train_centered = jnp.array([[_a,_b] for _a in jnp.linspace(-1/2,1/2,50) for _b in jnp.linspace(-1/2, 1/2, 50)])
y = jnp.asarray(500 * jnp.exp(-0.5 * x_train_centered[:,1]**2 / 0.1**2)).astype(int)
non_centered = True
ell = 5
y_train_obs = y
monitors = 1.
def model(x, y=None, monitors=1., max_k=1, basis_size=300, min_sigma=0.05):
    length = jnp.array([0.2, 0.1, 0.3])
    alpha = numpyro.sample('alpha', dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6))**2
    offset_guess = jnp.array([500, 0.5, 0.01])
    offsets = numpyro.sample('offsets', dist.Normal(offset_guess, offset_guess/2))
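    # HSGP machinery: Laplacian eigenfunctions on [-ell, ell] paired with the
    # squared-exponential spectral density evaluated at the eigenvalues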
    phi = eigenfunctions(x[:,0], ell, basis_size).T
    lam = diag_spectral_density_squared_exponential(alpha=alpha, length=length, ell=ell, m=basis_size, dim=1)
    beta = numpyro.sample('beta', dist.Normal().expand([3, max_k, basis_size]))
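    # one weighted basis expansion per output: the three rows of the einsum
    # result give the amplitude, location, and width surfaces respectively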
    amp, f, sigma = jnp.einsum('ij,ip,jki->jkp', lam**0.5, phi, beta) + offsets[:,None,None]
    amp = jax.nn.softplus(amp)
    sigma = jax.nn.softplus(sigma)
    rate = (amp * jnp.exp(-0.5 * (f-x[:,1])**2 / (sigma**2 + min_sigma**2))).sum(0) + 1e-4
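    # NOTE: the snippet defines `rate` but no observation site; this is an
    # assumed Poisson likelihood (counts scaled by `monitors`), consistent
    # with the integer y above -- adjust if the original used something else
    numpyro.sample('obs', dist.Poisson(rate * monitors), obs=y)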
max_k = 1
init = init_to_mean
optimizer = Adam(1e-3)
guide = autoguide.AutoDiagonalNormal(model, init_loc_fn=init)
# guide = autoguide.AutoNormal(model, init_loc_fn=init_to_value(values={"alpha": jnp.array([100., 1., 2.])}))
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_state = svi.init(rng_key=jax.random.PRNGKey(0), x=x_train_centered, y=y_train_obs, monitors=monitors, max_k=max_k)
@jax.jit
def svi_update(svi_state):
    return svi.update(svi_state=svi_state, x=x_train_centered, y=y_train_obs, monitors=monitors, max_k=max_k)
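# call once up front to trigger JIT compilation of the update step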
svi_update(svi_state)
num_steps = 20_000
for i in range(num_steps):
    svi_state, loss = svi_update(svi_state)
    if i % 1_000 == 0:
        print(f"Step {i}: loss = {loss:.3f}")
    if jnp.isnan(loss):
        raise Exception('nan loss')
params = svi.get_params(svi_state)
predictive = Predictive(model, guide=guide, params=params, num_samples=5)
rng_key, rng_subkey = jax.random.split(rng_key)
posterior_samples = predictive(
    rng_subkey, x_train_centered, y_train_obs,
    monitors=monitors, max_k=max_k
)
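(`posterior_samples` comes back as a dict of arrays keyed by site name; `jax.tree_util.tree_map(jnp.shape, posterior_samples)` is a quick way to see what was returned.)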
Outputs:
Step 0: loss = 1990.518
Step 1000: loss = 931.560
Step 2000: loss = 1073.600
Step 3000: loss = 1628.495
Step 4000: loss = 36.412
Step 5000: loss = 736.634
Step 6000: loss = 33.537
Step 7000: loss = 123.824
Step 8000: loss = 3.094
Step 9000: loss = 44.443
Step 10000: loss = 57.983
Step 11000: loss = 9.594
Step 12000: loss = 3.244
Step 13000: loss = 4.508
Step 14000: loss = 13.230
Step 15000: loss = 6.302
Step 16000: loss = 11.743
Step 17000: loss = 5.194
Step 18000: loss = 3.492
Step 19000: loss = 1.519