GP with plates

Hi all,
I am fairly new to NumPyro, so first of all, thanks for writing it. I am really enjoying it.

I have the following problem: I want to fit a hierarchical GP model using NumPyro’s plate notation. Consider this minimal (fictitious) reproducible example:

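import jax.numpy as np  # np is jax.numpy here, as in NumPyro's examples at the time
import numpyro
import numpyro.distributions as nd

def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    # squared-exponential (RBF) kernel with variance sigma and lengthscale rho
    X1_e = np.expand_dims(X1, 1) / rho            # (N, 1, D)
    X2_e = np.expand_dims(X2, 0) / rho            # (1, M, D)
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)    # (N, M) scaled squared distances
    # jitter on the diagonal (assumes X1 and X2 have the same number of rows)
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K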

def model(y, X, n_states):
    # one kernel variance shared by all states
    sigma = numpyro.sample("sigma", nd.LogNormal(0.0, 2.0))

    with numpyro.plate("states", size=n_states):
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 2.0))  # shape (n_states,)
        K = rbf(X, X, sigma, rho)
        L = np.linalg.cholesky(K)
        # non-centred parameterisation: f_tilde has shape (N, n_states)
        f_tilde = numpyro.sample("f_tilde", nd.Normal(loc=np.zeros((X.shape[0], 1))))
        f = numpyro.deterministic("f", L @ f_tilde)
    f = f.reshape(-1)

    noise = numpyro.sample("noise", nd.LogNormal(0.0, 3.0))
    numpyro.sample(
        "y",
        nd.Normal(f, noise),
        obs=y)

For each of n_states states I am fitting a GP using an RBF kernel with a state-specific lengthscale and a population-level variance. This runs without errors. However, when I instead try to give each state its own kernel variance, I run into broadcasting problems. For example, this model fails:

def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)    
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K

def model(y, X, n_states):    
    with numpyro.plate("states", size=n_states):                                
        sigma = numpyro.sample("sigma", nd.LogNormal(0.0, 2.0))
        rho = numpyro.sample("rho", nd.LogNormal(0.0, 2.0))
        K = rbf(X, X, sigma, rho)
        L = np.linalg.cholesky(K)
        f_tilde = numpyro.sample("f_tilde", nd.Normal(loc=np.zeros((X.shape[0], 1))))
        f = numpyro.deterministic("f", L @ f_tilde)
    f = f.reshape(-1)    

    noise = numpyro.sample("noise", nd.LogNormal(0.0, 3.0))
    numpyro.sample(
        "y",         
        nd.Normal(f, noise),
        obs=y)

The only difference is that I moved sigma into the plate block. The error is:

ValueError: Incompatible shapes for broadcasting: ((1, 5), (11, 11))

where 5 is the number of states n_states and 11 is the number of rows of X. Any clues where the problem might be?
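To see where the shapes go wrong, rbf can be called outside the model with the arrays the plate would pass in. A minimal sketch, assuming np is jax.numpy and a hypothetical single-column X of 11 points (the thread does not show X): a vector rho does not raise, because dividing the (11, 1, 1) array by shape (5,) broadcasts over the feature axis, and the sum over axis 2 then folds all five lengthscales into one shared (11, 11) kernel. A vector sigma has nothing to broadcast against, so `sigma * np.exp(...)` fails; JAX pads the (5,) shape to (1, 5) before comparing, which is likely why the message shows (1, 5).

```python
import jax.numpy as np

# rbf copied from the question
def rbf(X1, X2, sigma=1.0, rho=1.0, jitter=1.0e-6):
    X1_e = np.expand_dims(X1, 1) / rho
    X2_e = np.expand_dims(X2, 0) / rho
    dists = np.sum((X1_e - X2_e) ** 2, axis=2)
    K = sigma * np.exp(-0.5 * dists) + np.eye(dists.shape[0]) * jitter
    return K

n_states, n_obs = 5, 11
X = np.linspace(0.0, 1.0, n_obs).reshape(n_obs, 1)  # hypothetical inputs, shape (11, 1)
rho = np.ones(n_states)    # inside the plate, rho has shape (n_states,)
sigma = np.ones(n_states)  # ...and so does sigma once it is moved into the plate

# Vector rho: (11, 1, 1) / (5,) -> (11, 1, 5); summing axis 2 mixes the states,
# so a single (11, 11) kernel comes back instead of one kernel per state.
K_rho = rbf(X, X, 1.0, rho)
print(K_rho.shape)  # (11, 11)

# Vector sigma: (5,) * (11, 11) is not broadcastable.
try:
    rbf(X, X, sigma, rho)
except (TypeError, ValueError) as err:  # exception type differs across JAX versions
    print(type(err).__name__, err)
```

The same check explains why the first model runs: with a one-column X it silently uses one kernel, with a combined lengthscale, for every state.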

Thank you very much for any help!

Best,
Simon

@simonski I think you need to make your rbf function broadcastable. Currently, your implementation assumes that sigma and rho are scalars. This is easily remedied with the awesome jax.vmap, which maps rbf over the per-state values:

K = jax.vmap(lambda s, r: rbf(X, X, s, r))(sigma, rho)  # one (N, N) kernel per state: (n_states, N, N)

Great, that did it. Thanks!