HSGP fails when hyperparameters are sampled

I am working on a model that treats a 2D signal as a sum of 1D signals smoothed out in the 2nd dimension:

y_i \sim \mathrm{Poisson}\left( M \left( \epsilon + \sum_{k=1}^{K_{\max}} \mathrm{amp}_k(x_1) \exp\left( -\frac{1}{2} \frac{\left( f_k(x_1) - x_2 \right)^2}{\sigma_k(x_1)^2} \right) \right) \right)

I model the amplitude amp_k, center f_k, and width \sigma_k as functions of x_1 using HSGPs. My problem is that everything works when I use fixed GP hyperparameters, but as soon as I let the hyperparameters vary, initialization fails. Here is code that reproduces the problem:

import jax

import matplotlib.pyplot as plt
from jax import random
import jax.numpy as jnp
from numpyro.infer import SVI, Trace_ELBO, autoguide, Predictive
from numpyro.optim import Adam
import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)
import numpyro.distributions as dist

rng_key = random.PRNGKey(seed=42)
x_train_centered = jnp.array([[_a,_b] for _a in jnp.linspace(-1/2,1/2,50) for _b in jnp.linspace(-1/2, 1/2, 50)])
y = jnp.asarray(500 * jnp.exp(-0.5 * x_train_centered[:,1]**2 / 0.1**2)).astype(int)
plt.scatter(*x_train_centered.T, c=y, marker='s')
plt.colorbar()
non_centered = True
ell = 5

y_train_obs = y
monitors = 1.
def model(x, y=None, monitors=1., max_k=1, basis_size=300, min_sigma=0.05, make_error=True):
    length = jnp.array([0.2, 0.1, 0.3])
    if make_error:
        # log_alpha = numpyro.sample('log_alpha', dist.Normal(2*jnp.log(jnp.array([100., 1., 2.])), 1e-2))
        # alpha = jnp.exp(log_alpha)
        alpha = numpyro.sample('alpha', dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6))**2
    else:
        alpha = jnp.array([100., 1., 2.])**2

    offset_guess = jnp.array([500, 0.5, 0.01])
    offsets = numpyro.sample('offsets', dist.Normal(offset_guess, offset_guess/2))
    # length = numpyro.sample('length', dist.TruncatedNormal(length, length/2, low=0.1, high=1.))
    # alpha = numpyro.sample('alpha', dist.InverseGamma(jnp.array([10, 20., 10]), rate=jnp.array([1_000, 10., 10])))

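    # HSGP approximation: Laplacian eigenfunctions on [-ell, ell], weighted below
    # by the square root of the squared-exponential spectral density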
    phi = eigenfunctions(x[:,0], ell, basis_size).T
    lam = diag_spectral_density_squared_exponential(alpha=alpha, length=length, ell=ell, m=basis_size)
    beta = numpyro.sample('beta', dist.Normal().expand([3, max_k, basis_size])) 
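    # per the einsum: lam**0.5 is (m, 3) spectral weights, phi is (m, n) basis values,
    # beta is (3, max_k, m) coefficients; the result is (3, max_k, n), unpacked into
    # the amplitude, center, and width surfaces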
    amp, f, sigma = jnp.einsum('ij,ip,jki->jkp', lam**0.5, phi, beta) + offsets[:,None,None]
    amp = jax.nn.softplus(amp)
    sigma = jax.nn.softplus(sigma)
    rate = (amp * jnp.exp(-0.5 * (f-x[:,1])**2 / (sigma**2 + min_sigma**2))).sum(0) + 1e-4
    numpyro.sample('likelihood', dist.Poisson(rate*monitors), obs=y)
    numpyro.deterministic('_alpha', alpha)
    numpyro.deterministic('_length', length)
    numpyro.deterministic('_beta', beta)
    numpyro.deterministic('_amp', amp)
    numpyro.deterministic('_sigma', sigma)
    numpyro.deterministic('_f', f)
    numpyro.deterministic('_rate', rate)
    
from numpyro.infer import init_to_feasible, init_to_mean, init_to_median, init_to_uniform, init_to_sample, init_to_value
max_k = 1
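# Try each init strategy in turn; failures are caught and reported below so
# the remaining strategies still run.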
for init in [init_to_feasible, init_to_mean, init_to_median, init_to_uniform, init_to_sample]:
    try:
        optimizer = Adam(1e-3)
        guide = autoguide.AutoDiagonalNormal(model, init_loc_fn=init)
        # guide = autoguide.AutoNormal(model, init_loc_fn=init_to_value(values={"alpha": jnp.array([100., 1., 2.])}))
        svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
        svi_state = svi.init(rng_key=jax.random.PRNGKey(0), x=x_train_centered, y=y_train_obs, monitors=monitors, max_k=max_k)
        @jax.jit
        def svi_update(svi_state):
            return svi.update(svi_state=svi_state, x=x_train_centered, y=y_train_obs, monitors=monitors, max_k=max_k)
        svi_update(svi_state)
        num_steps = 50_000
        try:
            for i in range(num_steps):
                svi_state, loss = svi_update(svi_state)
                if i % 1_000 == 0:
                    print(f"Step {i}: loss = {loss:.3f}")
                if jnp.isnan(loss):
                    raise Exception('nan loss')
        except KeyboardInterrupt:
            pass
        params = svi.get_params(svi_state)


        predictive = Predictive(model, guide=guide, params=params, num_samples=5)
        rng_key, rng_subkey = jax.random.split(rng_key)
        posterior_samples = predictive(
            rng_subkey, x_train_centered, y_train_obs,
            monitors=monitors, max_k=max_k
        )
    except Exception as e:
        print(f'init {init} failed')
        # print(e)
        continue

Everything that needs to be non-negative has a softplus applied. The only change relative to the working setup is replacing a fixed value with a TruncatedNormal centered at that same fixed value with a tiny scale. I don't understand how one setup can fail while the other doesn't. The error messages are "Cannot find valid initial parameters. Please check your model again." and "UserWarning: Out-of-support values provided to log prob method. The value argument should be within the support.", pointing at the line alpha = numpyro.sample('alpha', dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6))**2.
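For reference, a minimal sketch (not part of the failing run above) of how one can inspect where the prior-based initialization lands, by tracing a single prior draw from the model; model, x_train_centered, y_train_obs, monitors, and max_k are the objects defined above:

from numpyro import handlers

# Sketch: draw once from the priors and print the range of each latent site,
# to see which sample statement produces extreme values at initialization.
with handlers.seed(rng_seed=0):
    exec_trace = handlers.trace(model).get_trace(
        x_train_centered, y_train_obs, monitors=monitors, max_k=max_k
    )
for name, site in exec_trace.items():
    if site["type"] == "sample" and not site["is_observed"]:
        values = jnp.ravel(site["value"])
        print(name, float(values.min()), float(values.max()))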

@juanitorduz Could you help @elliottperryman here?

I will take a look at this in the next few days :folded_hands:.

Actually, I am having problems reproducing the code. The first thing I spotted is that you are missing the required dim parameter in

diag_spectral_density_squared_exponential(alpha=alpha, length=length, ell=ell, m=basis_size)

So it has to be:

diag_spectral_density_squared_exponential(alpha=alpha, length=length, ell=ell, m=basis_size, dim=1)

(if the input is one-dimensional)

Also, in the Hilbert Space Approximation Gaussian Process Module in the NumPyro documentation we use an amplitude coming from a distribution (and likewise, in two dimensions, in Example: Hilbert space approximation for Gaussian processes (multidimensional)).
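For concreteness, a sketch of that pattern, where the LogNormal and InverseGamma priors are placeholder choices rather than anything taken from your model:

import jax
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.hsgp.approximation import hsgp_squared_exponential

# Sketch: sample the GP amplitude and lengthscale from priors and let the
# HSGP helper build the approximation; the priors here are placeholders.
def gp_model(x, y=None, ell=5.0, m=300):
    alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0))       # amplitude
    length = numpyro.sample("length", dist.InverseGamma(5.0, 1.0))  # lengthscale
    f = hsgp_squared_exponential(x=x, alpha=alpha, length=length, ell=ell, m=m)
    numpyro.sample("likelihood", dist.Poisson(jax.nn.softplus(f)), obs=y)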

Could you maybe share a very minimal code example that reproduces the error (without the plotting commands, and with just the init line that generates it)? :slight_smile:

This works for me:

import jax
from numpyro.infer import init_to_mean
import matplotlib.pyplot as plt
from jax import random
import jax.numpy as jnp
from numpyro.infer import SVI, Trace_ELBO, autoguide, Predictive
from numpyro.optim import Adam
import numpyro
from numpyro.contrib.hsgp.laplacian import eigenfunctions
from numpyro.contrib.hsgp.spectral_densities import (
    diag_spectral_density_squared_exponential,
)
import numpyro.distributions as dist

rng_key = random.PRNGKey(seed=42)
x_train_centered = jnp.array([[_a,_b] for _a in jnp.linspace(-1/2,1/2,50) for _b in jnp.linspace(-1/2, 1/2, 50)])
y = jnp.asarray(500 * jnp.exp(-0.5 * x_train_centered[:,1]**2 / 0.1**2)).astype(int)
non_centered = True
ell = 5

y_train_obs = y
monitors = 1.
def model(x, y=None, monitors=1., max_k=1, basis_size=300, min_sigma=0.05):
    length = jnp.array([0.2, 0.1, 0.3])

    alpha = numpyro.sample('alpha', dist.TruncatedNormal(jnp.array([100., 1., 2.]), 0.1, low=1e-6))**2
    offset_guess = jnp.array([500, 0.5, 0.01])
    offsets = numpyro.sample('offsets', dist.Normal(offset_guess, offset_guess/2))

    phi = eigenfunctions(x[:,0], ell, basis_size).T
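    # the fix: diag_spectral_density_squared_exponential requires dim (here 1)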
    lam = diag_spectral_density_squared_exponential(alpha=alpha, length=length, ell=ell, m=basis_size, dim=1)
    beta = numpyro.sample('beta', dist.Normal().expand([3, max_k, basis_size])) 
    amp, f, sigma = jnp.einsum('ij,ip,jki->jkp', lam**0.5, phi, beta) + offsets[:,None,None]
    amp = jax.nn.softplus(amp)
    sigma = jax.nn.softplus(sigma)
    rate = (amp * jnp.exp(-0.5 * (f-x[:,1])**2 / (sigma**2 + min_sigma**2))).sum(0) + 1e-4
    numpyro.sample('likelihood', dist.Poisson(rate*monitors), obs=y)

max_k = 1
init = init_to_mean

optimizer = Adam(1e-3)
guide = autoguide.AutoDiagonalNormal(model, init_loc_fn=init)
# guide = autoguide.AutoNormal(model, init_loc_fn=init_to_value(values={"alpha": jnp.array([100., 1., 2.])}))
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_state = svi.init(rng_key=jax.random.PRNGKey(0), x=x_train_centered, y=y_train_obs, monitors=monitors, max_k=max_k)


@jax.jit
def svi_update(svi_state):
    return svi.update(svi_state=svi_state, x=x_train_centered, y=y_train_obs, monitors=monitors, max_k=max_k)
svi_update(svi_state)
num_steps = 20_000

for i in range(num_steps):
    svi_state, loss = svi_update(svi_state)
    if i % 1_000 == 0:
        print(f"Step {i}: loss = {loss:.3f}")
    if jnp.isnan(loss):
        raise Exception('nan loss')

params = svi.get_params(svi_state)


predictive = Predictive(model, guide=guide, params=params, num_samples=5)
rng_key, rng_subkey = jax.random.split(rng_key)
posterior_samples = predictive(
    rng_subkey, x_train_centered, y_train_obs,
    monitors=monitors, max_k=max_k
)

Outputs:

Step 0: loss = 1990.518
Step 1000: loss = 931.560
Step 2000: loss = 1073.600
Step 3000: loss = 1628.495
Step 4000: loss = 36.412
Step 5000: loss = 736.634
Step 6000: loss = 33.537
Step 7000: loss = 123.824
Step 8000: loss = 3.094
Step 9000: loss = 44.443
Step 10000: loss = 57.983
Step 11000: loss = 9.594
Step 12000: loss = 3.244
Step 13000: loss = 4.508
Step 14000: loss = 13.230
Step 15000: loss = 6.302
Step 16000: loss = 11.743
Step 17000: loss = 5.194
Step 18000: loss = 3.492
Step 19000: loss = 1.519