# GP with plates

Ciao all,
I am fairly new to NumPyro, so first of all, thanks for writing it. I am really liking it.

I have the following problem: I want to fit a hierarchical GP model using NumPyro’s plate notation. Consider this minimal (fictitious) reproducible example:

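```python
import jax.numpy as np  # assumption: `np` is jax.numpy
import numpyro
import numpyro.distributions as nd


def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # Squared-exponential (RBF) kernel written for scalar sigma and rho
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K


def model(y, X, n_states):
    # Population-level kernel variance shared by all states
    sigma = numpyro.sample("sigma", nd.LogNormal(0.0, 2.0))

    with numpyro.plate("states", size=n_states):
        # rho has shape (n_states,): one lengthscale per state
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 2.0))
        K = rbf(X, X, sigma, rho)
        L = np.linalg.cholesky(K)
        # Non-centered parameterization: f = L @ f_tilde with standard-normal f_tilde
        f_tilde = numpyro.sample("f_tilde", nd.Normal(loc=np.zeros((X.shape[0], 1))))
        f = numpyro.deterministic("f", L @ f_tilde)
        f = f.reshape(-1)

    noise = numpyro.sample("noise", nd.LogNormal(0.0, 3.0))
    numpyro.sample(
        "y",
        nd.Normal(f, noise),
        obs=y)
```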

For each of the `n_states` states I am fitting a GP with an RBF kernel that has a state-specific lengthscale and a population-level variance. This runs fine with no errors. However, when I instead try to give each state its own kernel variance, I run into broadcasting problems. For example, this model fails:

```python
import jax.numpy as np  # assumption: `np` is jax.numpy
import numpyro
import numpyro.distributions as nd


def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # Squared-exponential (RBF) kernel written for scalar sigma and rho
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K


def model(y, X, n_states):
    with numpyro.plate("states", size=n_states):
        # Both kernel hyperparameters now have shape (n_states,)
        sigma = numpyro.sample("sigma", nd.LogNormal(0.0, 2.0))
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 2.0))
        K = rbf(X, X, sigma, rho)  # the ValueError below is raised inside this call
        L = np.linalg.cholesky(K)
        f_tilde = numpyro.sample("f_tilde", nd.Normal(loc=np.zeros((X.shape[0], 1))))
        f = numpyro.deterministic("f", L @ f_tilde)
        f = f.reshape(-1)

    noise = numpyro.sample("noise", nd.LogNormal(0.0, 3.0))
    numpyro.sample(
        "y",
        nd.Normal(f, noise),
        obs=y)
```

The only difference is that I moved the `sigma` sample inside the plate block. The error is:

```
ValueError: Incompatible shapes for broadcasting: ((1, 5), (11, 11))
```

where 5 is the number of states (`n_states`) and 11 is the number of rows of `X`. Any clues as to where the problem might be?
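A minimal sketch of the shape clash, reproduced outside the model: inside the plate, `sigma` has shape `(n_states,)`, while `np.exp(-0.5 * dists)` in `rbf` has shape `(N, N)`. Broadcasting aligns trailing axes, so 5 is compared against 11 and JAX refuses (the `(1, 5)` in the message appears to be `sigma` padded to two dimensions). The helper `broadcasts` and the `jnp.ones` placeholders are hypothetical, chosen only to reproduce the shapes from the post.

```python
import jax.numpy as jnp


def broadcasts(a, b):
    """Hypothetical helper: True if a * b broadcasts in JAX, False if the shapes are rejected."""
    try:
        a * b
        return True
    except (ValueError, TypeError):
        return False


n_states, n = 5, 11
kernel_part = jnp.ones((n, n))  # stands in for np.exp(-0.5 * dists), shape (11, 11)

# Shared sigma (first model): a scalar broadcasts against any shape
print(broadcasts(jnp.array(1.0), kernel_part))  # True
# Per-state sigma (second model): (5,) vs (11, 11), trailing sizes 5 != 11
print(broadcasts(jnp.ones(n_states), kernel_part))  # False
# With two extra trailing axes, sigma would broadcast to one kernel per state
print((jnp.ones((n_states, 1, 1)) * kernel_part).shape)  # (5, 11, 11)
```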

Thank you very much for any help!

Best,
Simon

@simonski I think you need to make your `rbf` function broadcastable. Currently, your implementation assumes `sigma` and `rho` are scalars. This is easily remedied with the awesome `jax.vmap`, which calls `rbf` once per state and stacks the resulting kernels:

```python
# sigma and rho have shape (n_states,); K has shape (n_states, X.shape[0], X.shape[0])
K = jax.vmap(lambda s, r: rbf(X, X, s, r))(sigma, rho)
```
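To see what the `vmap` call does, here is a small sketch with made-up values (an `X` of 11 one-dimensional points and five hypothetical `sigma`/`rho` values). `vmap` maps over axis 0 of `sigma` and `rho`, so `rbf` only ever sees scalars, and the kernels are stacked along a new leading axis. The last lines contrast this with passing the `rho` vector straight into `rbf`: for one-feature inputs (an assumption; the thread doesn't give the shape of `X`) that call does not fail, but `rho` broadcasts along the feature axis and all five lengthscales end up in a single `(11, 11)` kernel, which is likely also what happens to `rho` in the first model from the question.

```python
import jax
import jax.numpy as jnp


def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # Same RBF kernel as in the question, written for scalar sigma and rho
    X1_e = jnp.expand_dims(X1, 1) / rho
    X2_e = jnp.expand_dims(X2, 0) / rho
    dists = jnp.sum((X1_e - X2_e) ** 2, axis=2)
    return sigma * jnp.exp(-0.5 * dists) + jnp.eye(dists.shape[0]) * jitter


# Hypothetical data: 11 one-dimensional inputs and 5 states
X = jnp.linspace(0.0, 1.0, 11)[:, None]
sigma = jnp.array([0.5, 1.0, 1.5, 2.0, 2.5])  # one variance per state
rho = jnp.array([0.2, 0.4, 0.6, 0.8, 1.0])  # one lengthscale per state

# vmap maps over axis 0 of sigma and rho, calling rbf with scalars each time
K = jax.vmap(lambda s, r: rbf(X, X, s, r))(sigma, rho)
print(K.shape)  # (5, 11, 11)

# Slice k is the kernel built from the k-th state's parameters
K_2 = rbf(X, X, sigma[2], rho[2])
print(jnp.allclose(K[2], K_2))  # True

# Contrast: the whole rho vector passed directly broadcasts along the feature
# axis, and the scaled distances of all states are summed into one matrix
K_mixed = rbf(X, X, 1.0, rho)
print(K_mixed.shape)  # (11, 11)
```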

Great, that did it. Thanks!