NaNs during NUTS

I'm fitting a Gaussian-process-type model in which not all of the inputs (Xs) are known, and I'm using MCMC to marginalize the unknown ones out.

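import numpy as np
import torch
import pyro
import pyro.distributions as dist
import pyro.contrib.gp as gp
from pyro.infer.mcmc import MCMC, NUTS

# Get the true function and sample 50 points randomly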
def simple_gp(X, kernel, sigma):
    """Draws from a GP"""
    cov = kernel(X)
    cov += torch.eye(len(X)) * 1e-5
    f = dist.MultivariateNormal(torch.zeros(len(X)), cov).sample()
    y = dist.MultivariateNormal(f, torch.eye(len(X)) * sigma).sample()
    return f, y

k = gp.kernels.Matern52(1, variance=torch.Tensor([1.0]), lengthscale=torch.Tensor([0.2]))
X = torch.linspace(-1, 1, 500)
sigma = torch.Tensor([0.1])
torch.manual_seed(7345)
f_true, y_true = simple_gp(X, k, sigma)

n_train = 50
train_inds = torch.LongTensor(np.random.choice(len(X), size=n_train, replace=False).astype(int)).sort()[0]
X_train = X[train_inds]
y_train = y_true[train_inds]

# Fully observe the top 10, but only get the labels for the bottom 40
n_obs = 10
n_train = len(y_train)
n_hyp = n_train - n_obs
y_sort, inds = torch.sort(y_train)
y_obs = y_sort[-n_obs:]
X_obs = X_train[inds][-n_obs:].unsqueeze(1)
y_hyp = y_sort[:n_hyp]
X_hyp = X_train[inds][:n_hyp].unsqueeze(1)

# Specify the model
def Matern52(X1, X2, alpha, rho):
    L2 = torch.abs(torch.sum(X1 ** 2, dim=1).unsqueeze(1) + torch.sum(X2 ** 2, dim=1).unsqueeze(0) - 2 * X1 @ X2.t())
    L = torch.sqrt(L2)
    sqrt5_r = 5**0.5 * L
    return alpha * (1 + sqrt5_r / rho + (5/3) * L2 / rho ** 2) * torch.exp(-sqrt5_r / rho)

def ma_gp_hs(X_obs, n_hyp):
    alpha = pyro.sample('alpha', dist.Uniform(torch.Tensor([0.1]), torch.Tensor([2])))
    sigma = pyro.sample('sigma', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([1])))
    rho = pyro.sample('rho', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([2])))
    X_hyp = pyro.sample('X_hyp', dist.Uniform(torch.ones(n_hyp) - 2, torch.ones(n_hyp))).unsqueeze(1)
    X_train = torch.cat([X_obs, X_hyp])
    # Covariance
    n = len(X_train)
    cov = Matern52(X_train, X_train, alpha, rho) + torch.eye(n) * (sigma + 1e-5)
    # Lower-triangular Cholesky factor (torch.potrf(cov, upper=False) in old PyTorch releases)
    L = torch.linalg.cholesky(cov)
    # Likelihood
    return pyro.sample('f', dist.MultivariateNormal(torch.zeros(n), scale_tril=L))
    
def conditioned_gp_regressor(gp_regressor, X_obs, y_train, n_hyp):
    return pyro.poutine.condition(gp_regressor, data={'f': y_train})(X_obs, n_hyp)

nuts_kernel = NUTS(conditioned_gp_regressor, step_size=0.1, adapt_step_size=False)
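# f is ordered like torch.cat([X_obs, X_hyp]), so the observed y must follow the same order
y_cond = torch.cat([y_obs, y_hyp])
posterior = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500).run(ma_gp_hs, X_obs, y_cond, n_hyp)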

Consistently, no matter the step size, at the fourth iteration X_hyp becomes

tensor([ -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,
         -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,
         -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,  -1.,
         -1.,  -1., nan., nan., nan., nan., nan., nan., nan., nan.])

Would anybody know how to fix this?

My understanding of your particular example is limited, but here's one observation: if you don't want to do inference over certain variables and only want to sample them, you can do so directly by calling dist.sample(). Every unobserved pyro.sample site is treated as a latent variable and becomes part of the Hamiltonian trajectories generated by the integrator. So in your example, if X_hyp does not need to be treated as a latent site, you should use:

X_hyp = dist.Uniform(torch.ones(n_hyp) - 2, torch.ones(n_hyp)).sample().unsqueeze(1)

In this case, X_hyp does need to be treated as a latent site.

Got it. Someone more familiar with implementing GPs in Pyro might be able to provide some insight. cc @fehiepsi, @fritzo

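I believe the bug is in the line L = torch.sqrt(L2). torch.sqrt has no finite gradient at 0 (its derivative 1/(2*sqrt(x)) blows up there), and the diagonal of L2 is always 0 because every point has zero distance to itself, so the gradient with respect to X_hyp becomes NaN. We avoid this bug by setting L = torch.sqrt(L2 + 1e-12).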
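To see the failure in isolation, here is a minimal sketch in plain PyTorch (no Pyro). The helper names sq_dist, matern52 and input_grad are hypothetical; the kernel is a slightly rewritten version of the Matern52 function in the question, using clamp(min=0) instead of abs. With eps=0 the diagonal distances are exactly zero and backpropagating through torch.sqrt yields NaN in the input gradient; a tiny eps keeps it finite.

```python
import torch

def sq_dist(X1, X2):
    """Pairwise squared Euclidean distances, clamped at 0 to absorb round-off."""
    X1_sq = (X1 ** 2).sum(dim=1, keepdim=True)      # shape (n1, 1)
    X2_sq = (X2 ** 2).sum(dim=1, keepdim=True).t()  # shape (1, n2)
    return (X1_sq + X2_sq - 2 * X1 @ X2.t()).clamp(min=0)

def matern52(X1, X2, alpha, rho, eps=0.0):
    """Matern 5/2 kernel; eps is added under the square root."""
    r = torch.sqrt(sq_dist(X1, X2) + eps)
    s = 5 ** 0.5 * r / rho
    return alpha * (1 + s + s ** 2 / 3) * torch.exp(-s)

def input_grad(eps):
    """Gradient of sum(K) with respect to the inputs X."""
    # values chosen so their squares are exact in float32 (diagonal of sq_dist is exactly 0)
    X = torch.tensor([[-0.5], [0.0], [0.75]], requires_grad=True)
    K = matern52(X, X, alpha=torch.tensor(1.0), rho=torch.tensor(0.2), eps=eps)
    K.sum().backward()
    return X.grad

# eps=0: sqrt at the zero diagonal poisons the gradient with NaN
print(torch.isnan(input_grad(0.0)).any())
# eps=1e-12: the gradient stays finite
print(torch.isfinite(input_grad(1e-12)).all())
```

The same change applies to the Matern52 function in the question: torch.sqrt(L2 + 1e-12) shifts each distance by at most sqrt(1e-12) = 1e-6, so the kernel values are essentially unchanged while the gradients stay finite.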

By the way, you can define a GP model once and use its set_data method to swap in new data. This will simplify your code a lot! :slight_smile: Something like this:

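import torch
import pyro
import pyro.contrib.gp as gp

kernel = gp.kernels.Matern52(...)
# then place Uniform priors on Matern52's parameters (variance, lengthscale)
gp_regressor = gp.models.GPRegression(X_obs, y_obs, kernel=kernel)

def conditioned_gp_regressor(X_obs, y_train, n_hyp):
    # the unknown inputs remain a latent sample site
    X_hyp = pyro.sample(...)
    X = torch.cat([X_obs, X_hyp])
    # swap in the full inputs; y_train is observed inside model()
    gp_regressor.set_data(X, y_train)
    return gp_regressor.model()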

A few minor things: use torch.tensor instead of torch.Tensor; there's no need to unsqueeze X into a 2D tensor (the GP module accepts 1D X for most kernels); use torch.clamp(min=0) instead of torch.abs() for L2; use the GPRegression forward method instead of writing a simple_gp function; and use gp.kernels.Matern52 with the set_prior method instead of writing your own Matern52 kernel. :wink:
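A quick plain-PyTorch illustration of the first and third tips, offered as a sketch. torch.Tensor is the legacy constructor: it always uses the default float type (float32), and an integer argument is read as a shape rather than a value, whereas torch.tensor builds a tensor from the data you pass and infers its dtype. For L2, round-off can leave a slightly negative squared distance; abs turns it into a small spurious positive distance, while clamp(min=0) maps it to exactly zero. The value -1e-7 below is only a stand-in for such round-off.

```python
import torch

a = torch.Tensor([0.2])  # legacy constructor: float32, shape (1,)
b = torch.tensor(0.2)    # factory function: 0-dim float32 tensor holding 0.2
c = torch.Tensor(3)      # an *uninitialized* float tensor of shape (3,), not the number 3
d = torch.tensor(3)      # 0-dim int64 tensor holding 3

L2 = torch.tensor([-1e-7, 0.04])  # a squared distance perturbed by round-off
print(L2.abs())         # the negative entry becomes a spurious +1e-7
print(L2.clamp(min=0))  # the negative entry becomes exactly 0
```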

That fixes it.

Thanks so much!