I’m trying to fit an RBF basis function model with K=10 bases to the Friedman function. However, the model simply fails to fit. I’m wondering if anyone can see anything wrong here or has any experience with RBF basis functions where you are learning the centers?
Specifically, I’m trying to fit a sparse basis function model (although it doesn’t work with just regular coefficients).
import json
import matplotlib.pyplot as plt
from jax import random, lax
import jax.numpy as np
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.infer import SVI, ELBO
from numpyro.infer import MCMC, NUTS, HMC
import numpy as onp
def construct_basis(x,centers,length):
### ( n * p) (p * j )
tmp = (x[:,None,:] -centers)
tmp = (tmp)**2
tmp = np.sum(tmp,axis=(2))
tmp = np.exp(-1/length * tmp)
return (tmp)
def model(x, xtest, y):
# Cosntants
N = y.shape[0]
k = 10
noise = 0#numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
length = .05# numpyro.sample("kernel_length", dist.Gamma(1,1))
centers = numpyro.sample('center',dist.Uniform(np.zeros((k,5)),np.ones((k,5))))
kern_tot = construct_basis(x,centers,length)
lambdas = numpyro.sample("lambdas", dist.HalfCauchy(np.ones(k)))
tau = numpyro.sample("tau", dist.HalfCauchy(np.ones(1)))
# note that this reparameterization (i.e. coordinate transformation) improves
# posterior geometry and makes NUTS sampling more efficient
unscaled_betas = numpyro.sample("unscaled_betas", dist.Normal(0.0, np.ones(k)))
scaled_betas = numpyro.deterministic("betas", tau * lambdas * unscaled_betas)
y_hat = np.dot(kern_tot,scaled_betas)
numpyro.deterministic("y_hat",y_hat)
numpyro.sample("y_obs", dist.Normal(y_hat, .005),obs=y)
# Local variables version:
def f(x):
return 10*onp.sin(onp.pi*x[:,0]*x[:,1])+20*(x[:,2]-.5)**2+10*x[:,3]+5*x[:,4]
sigma = .005 # noise sd
n = 1000 # number of observations
x = onp.random.uniform(0,1,(n,5))
y = f(x) + onp.random.normal(0,sigma,n)
xtest = onp.random.uniform(0,1,(n,5))
ytest = f(xtest)
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
kernel = NUTS(model)
num_samples = 2000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples)
mcmc.run(
rng_key_, x=x,y=y,xtest=xtest
)
mcmc.print_summary()
samples_1 = mcmc.get_samples()
import matplotlib.pyplot as plt
plt.plot(np.mean(samples_1['y_hat'],axis=0),y,'ro')
plt.show()
``