Ciao all,
I am fairly new to NumPyro, so first of all thanks for writing it. I am really liking it.
I have the following problem: I want to fit a hierarchical GP model using NumPyro’s plate notation. Consider this minimal (fictitious) reproducible example:
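
# Assumed imports: np is taken to be jax.numpy and nd to be numpyro.distributions.
import jax.numpy as np
import numpyro
import numpyro.distributions as nd

def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # Scale the inputs by the lengthscale, then take pairwise squared distances.
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)
    # Jitter on the diagonal keeps the Cholesky factorization stable.
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K

def model(y, X, n_states):
    # Population-level kernel variance.
    sigma = numpyro.sample("sigma", nd.LogNormal(0.0, 2.0))
    with numpyro.plate("states", size=n_states):
        # State-specific lengthscale.
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 2.0))
        K = rbf(X, X, sigma, rho)
        L = np.linalg.cholesky(K)
        f_tilde = numpyro.sample("f_tilde", nd.Normal(loc=np.zeros((X.shape[0], 1))))
        f = numpyro.deterministic("f", L @ f_tilde)
    f = f.reshape(-1)
    noise = numpyro.sample("noise", nd.LogNormal(0.0, 3.0))
    numpyro.sample("y", nd.Normal(f, noise), obs=y)
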
For each of the n_states states, I am fitting a GP with an RBF kernel that has a state-specific lengthscale and a population-level variance. This runs without errors. However, when I specify state-specific kernel variances instead, I run into broadcasting problems. For example, this model fails:

def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K

def model(y, X, n_states):
    with numpyro.plate("states", size=n_states):
        sigma = numpyro.sample("sigma", nd.LogNormal(0.0, 2.0))
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 2.0))
        K = rbf(X, X, sigma, rho)
        L = np.linalg.cholesky(K)
        f_tilde = numpyro.sample("f_tilde", nd.Normal(loc=np.zeros((X.shape[0], 1))))
        f = numpyro.deterministic("f", L @ f_tilde)
    f = f.reshape(-1)
    noise = numpyro.sample("noise", nd.LogNormal(0.0, 3.0))
    numpyro.sample("y", nd.Normal(f, noise), obs=y)

The only difference is that I am moving sigma into the plate block. The error is:
ValueError: Incompatible shapes for broadcasting: ((1, 5), (11, 11))
where 5 is the number of states (n_states) and 11 is the number of rows of X. Any clues where the problem might be?
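
To make the shapes concrete, here is a minimal sketch that traces the arithmetic inside rbf using plain NumPy instead of NumPyro. It assumes X has shape (11, 1) and uses arrays of ones as stand-ins for the plated sigma and rho samples; those choices, and the batched kernel at the end, are illustrative assumptions rather than a confirmed fix. Under these assumptions the division by rho broadcasts, but the state axis is then summed away by axis=2, so dists is (11, 11) and multiplying it by a length-5 sigma fails.

```python
import numpy as np

n_states, n_obs = 5, 11
X = np.linspace(0.0, 1.0, n_obs).reshape(n_obs, 1)  # assumed input shape (11, 1)
sigma = np.ones(n_states)  # stand-in for the plated "sigma" sample, shape (5,)
rho = np.ones(n_states)    # stand-in for the plated "rho" sample, shape (5,)

# Same arithmetic as in rbf: the division by rho broadcasts along the last axis.
X1_e = np.expand_dims(X, 1) / rho   # (11, 1, 1) / (5,) -> (11, 1, 5)
X2_e = np.expand_dims(X, 0) / rho   # (1, 11, 1) / (5,) -> (1, 11, 5)
dists = np.sum((X1_e - X2_e) ** 2, axis=2)  # state axis is summed away -> (11, 11)

# Multiplying a (5,) vector with an (11, 11) matrix cannot broadcast.
try:
    sigma * np.exp(-0.5 * dists)
    broadcast_failed = False
except ValueError:
    broadcast_failed = True

# One possible alternative (an assumption): keep an explicit leading state axis.
sq = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)  # (11, 11)
K = sigma[:, None, None] * np.exp(-0.5 * sq / rho[:, None, None] ** 2)  # (5, 11, 11)
```

If this reading is right, the first model only runs because rho happens to broadcast against the trailing feature axis, not because each state gets its own kernel matrix.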
Thank you very much for any help!
Best,
Simon