Can't learn simple radial basis function model

I’m trying to fit a radial basis function (RBF) model with K = 10 basis functions to the Friedman function, but the model simply fails to fit. Can anyone see what is wrong here, or does anyone have experience with RBF models in which the centers are learned?

Specifically, I’m trying to fit a sparse basis function model with horseshoe-style shrinkage on the coefficients (although it also fails with ordinary, unshrunk coefficients).

import json
import matplotlib.pyplot as plt
from jax import random, lax
import jax.numpy as np
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer import SVI, ELBO
from numpyro.infer import MCMC, NUTS, HMC
import numpy as onp

def construct_basis(x, centers, length):
    """Evaluate Gaussian radial basis functions of ``x`` at every center.

    Args:
        x: input points, shape (n, p).
        centers: basis-function centers, shape (k, p).
        length: scale that divides the squared distance; larger values give
            wider (smoother) bases.

    Returns:
        Array of shape (n, k) whose entry [i, j] is
        exp(-||x[i] - centers[j]||^2 / length).
    """
    # Insert a center axis so every point is compared with every center:
    # (n, 1, p) - (k, p) broadcasts to (n, k, p).
    diffs = x[:, None, :] - centers
    sq_dist = np.sum(diffs ** 2, axis=2)  # (n, k) squared Euclidean distances
    scale = -1 / length
    return np.exp(scale * sq_dist)

def model(x, xtest, y, k=10, length=.05, obs_sd=.005):
    """Sparse RBF regression with learned centers and horseshoe-shrunk weights.

    Places ``k`` Gaussian RBF centers uniformly in the unit hypercube, builds
    the (N, k) basis matrix with ``construct_basis`` and combines its columns
    with coefficients ``betas = tau * lambdas * unscaled_betas`` (half-Cauchy
    global/local scales, i.e. a horseshoe prior).

    Args:
        x: training inputs, shape (N, p); centers get the same p.
        xtest: unused; kept so callers passing ``xtest=`` keep working.
        y: observed targets, shape (N,), or None to sample ``y_obs``.
        k: number of basis functions (default 10, as before).
        length: RBF length parameter; each basis is
            exp(-||x - c||^2 / length).
        obs_sd: fixed standard deviation of the Gaussian likelihood.
    """
    p = x.shape[1]

    # NOTE(review): this is probably why the fit fails. If x is roughly
    # uniform on [0, 1]^p (as in the driver script) and the centers are too,
    # the expected squared point-to-center distance is p/6 (~0.83 for p=5).
    # With length=.05, a typical basis value is then exp(-0.83/0.05) ~ 6e-8,
    # so nearly every feature is ~0. There is also no intercept, and
    # obs_sd=.005 makes the likelihood extremely sharp for NUTS. Try a much
    # larger `length` (or put a prior on it) first.
    centers = numpyro.sample(
        "center", dist.Uniform(np.zeros((k, p)), np.ones((k, p)))
    )
    kern_tot = construct_basis(x, centers, length)  # (N, k)

    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(np.ones(k)))
    tau = numpyro.sample("tau", dist.HalfCauchy(np.ones(1)))

    # Non-centered parameterization: sampling unit-scale betas and rescaling
    # improves posterior geometry and makes NUTS sampling more efficient.
    unscaled_betas = numpyro.sample("unscaled_betas", dist.Normal(0.0, np.ones(k)))
    scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)

    # Linear combination of the basis functions: (N, k) @ (k,) -> (N,).
    y_hat = kern_tot @ scaled_betas

    numpyro.sample("y_obs", dist.Normal(y_hat, obs_sd), obs=y)


		# Local variables version:

def f(x):
    """Evaluate the Friedman #1 test function on each row of ``x``.

    Only the first five columns of ``x`` (shape (n, >=5)) are used:
    10*sin(pi*x0*x1) + 20*(x2 - 0.5)^2 + 10*x3 + 5*x4.
    Returns a length-n array.
    """
    x0, x1, x2, x3, x4 = (x[:, col] for col in range(5))
    return (
        10 * onp.sin(onp.pi * x0 * x1)
        + 20 * (x2 - 0.5) ** 2
        + 10 * x3
        + 5 * x4
    )

# Simulate Friedman #1 data, then run NUTS on the RBF model.
sigma = .005  # noise sd
n = 1000  # number of observations

# Seed the NumPy data generator as well as the JAX key so runs are repeatable.
data_rng = onp.random.default_rng(0)
x = data_rng.uniform(0, 1, (n, 5))
y = f(x) + data_rng.normal(0, sigma, n)

xtest = data_rng.uniform(0, 1, (n, 5))
ytest = f(xtest)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
# Keyword arguments are forwarded to model(x, xtest, y).
mcmc.run(rng_key_, x=x, y=y, xtest=xtest)
samples_1 = mcmc.get_samples()

import matplotlib.pyplot as plt


Your basis looks a bit odd. Are you sure this is what you want, with so many highly correlated centers? If you want a Gaussian process, you might follow one of these examples: [links not preserved]

Yes, I want a fixed number of basis functions so that the model size doesn’t grow with n (unlike a GP). I do think the centers have a “label switching” problem.

HSGP (the Hilbert-space approximate Gaussian process) uses a fixed number of basis functions, if I recall correctly.