How to condition several independent components on different data

Hi, I am trying to fit a model in numpyro which looks like the following:

  1. Draw n hyperparameters p_i for i=1,...,n
  2. Using these n hyperparameters, fit n conditionally independent models, where model i uses p_i as its hyperparameter and its own set of input-output pairs.

As a simple example (without the hyperparameter draws, but using different (x, y) pairs for each independent GP), I attempted the following:

var = params['var']
noise = params['noise']
length = params['length']
def model(Xs, Ys, jitter=1e-6):
    N, n_s = Xs.shape
    for i in range(n_s):
        x = Xs[:,i]
        y = Ys[:,i]
        K = rbf_covariance(var, length, noise, x.reshape(-1,1), x.reshape(-1,1), jitter=jitter)
        
        numpyro.sample(
            f"Ys[:,{i}]",
            dist.MultivariateNormal(loc=np.zeros(K.shape[0]), covariance_matrix=K),
            obs=y
        )

This runs and the fitting completes; however, when I call mcmc.print_summary() I get the following error: ValueError: max() arg is an empty sequence. Also, mcmc.get_samples() returns an empty dictionary.

I have searched through the docs but cannot find anything that helps. Any help or guidance would be appreciated 🙂

Hello, I think you will need to post a complete runnable script that triggers the error if there's any hope of figuring out what's going on.

Hi, thanks for the quick reply! Here is a complete runnable script. A big part of it just generates a toy dataset; the actual model is towards the end, and the mcmc.print_summary() call at the end raises the error. I was looking into using plates, but I am not sure how to use a plate to index the particular input-output pair or hyperparameters for each site.

import os
import time
import jax.numpy as np
from jax import random
from jax import vmap, jit
from functools import partial
import numpyro
from numpyro.infer import init_to_median, Predictive, MCMC, NUTS
import numpyro.distributions as dist


@partial(jit, static_argnums=(5, 6))
def rbf_covariance(var, length, noise, x, xp, jitter=1.0e-6,
                   include_noise=True):
    diff = np.expand_dims(x / length, 1) - np.expand_dims(xp / length, 0)
    Z = var * np.exp(-0.5 * np.sum(diff**2, axis=2))  # diff is (N, M, D); sum over the feature axis D -> (N, M)

    if include_noise:
        return Z + (noise + jitter) * np.eye(x.shape[0])
    else:
        return Z

## generate dataset
###############################
def univariate_gp(x, y, mu, var, noise, length, jitter=1e-06):
    # compute kernel
    K = rbf_covariance(var, length, noise, x.reshape(-1,1), x.reshape(-1,1), jitter=jitter)

    numpyro.sample(
        "obs_y",
        dist.MultivariateNormal(loc=mu, covariance_matrix=K),
        obs=y,
    )

def gen_outer_gp(key, proj_data, mu, var, noise, length, jitter=1e-06):
    predictive = Predictive(univariate_gp, num_samples=1)
    pred = predictive(key, proj_data, None, mu, var, noise, length, jitter=jitter)
    return pred['obs_y'].flatten()

input_key = random.PRNGKey(778989)
projs_key = random.PRNGKey(3257)
outer_key = random.PRNGKey(2357)
D = 2
N = 50
n_s = 2
var = 1.0
noise = 0.1
length = 0.5

x1 = np.linspace(0,1,N).reshape(-1,1)
x2 = np.linspace(1,2,N).reshape(-1,1)
Xs = np.hstack((x1,x2))

def mu_f(x):
    y = x + 0.2 * (x ** 3) + 0.5 * ((0.5 + x) ** 2) * np.sin(4.0 * x)
    return y

mu_s = vmap(mu_f, in_axes=1, out_axes=1)(Xs)
outer_keys = random.split(outer_key, n_s)

Ys = vmap(lambda key, x, mu: gen_outer_gp(key, x, mu, var, noise, length), in_axes=(0,1,1), out_axes=1)(outer_keys, Xs, mu_s)
#######################################################


## model

def model(Xs, Ys, jitter=1e-6):
    N, n_s = Xs.shape
    
    for i in range(n_s):
        x = Xs[:,i]
        y = Ys[:,i]
        K = rbf_covariance(var, length, noise, x.reshape(-1,1), x.reshape(-1,1), jitter=jitter)
        
        numpyro.sample(
            f"Ys[:,{i}]",
            dist.MultivariateNormal(loc=np.zeros(K.shape[0]), covariance_matrix=K),
            obs=y
        )

# options for model
mcmc_config = {'num_warmup' : 1000, 'num_samples' : 1000, 'num_chains' : 1, 'thinning' : 2, 'init_strategy' : init_to_median(num_samples=10)}

#
seed = 342757
train_key = random.PRNGKey(seed)

# helper function for doing hmc inference
def run_inference(rng_key, mcmc_config, model, *args):
    num_warmup = mcmc_config['num_warmup']
    num_samples = mcmc_config['num_samples']
    num_chains = mcmc_config['num_chains']
    thinning = mcmc_config['thinning']
    init_strategy = mcmc_config['init_strategy']

    start = time.time()

    kernel = NUTS(model, init_strategy=init_strategy)
    
    mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, thinning=thinning,
                progress_bar="NUMPYRO_SPHINXBUILD" not in os.environ)
    
    mcmc.run(rng_key, *args)
    
    print('\nMCMC elapsed time:', time.time() - start)
    return mcmc

mcmc = run_inference(train_key, mcmc_config, model, Xs, Ys)

# raises ValueError: max() arg is an empty sequence
mcmc.print_summary()

# this would return {}
# mcmc.get_samples()

Oh, sorry, this is basically expected: your model has no latent variables defined, so there is nothing for MCMC to sample. See e.g. the GP example.

Ah yes, that makes sense, thanks. In the model I am interested in, there would be some latent variables outside the independent GPs. So should I add something like this?

def model(Xs, Ys, jitter=1e-6):
    N, n_s = Xs.shape
    
    var = numpyro.sample("var", dist.LogNormal(0.0,10.0))
    length = numpyro.sample("length", dist.LogNormal(0.0,10.0))
    noise = numpyro.sample("noise", dist.LogNormal(0.0,10.0))
    for i in range(n_s):
        x = Xs[:,i]
        y = Ys[:,i]
        K = rbf_covariance(var, length, noise, x.reshape(-1,1), x.reshape(-1,1), jitter=jitter)
        
        numpyro.sample(
            f"Ys[:,{i}]",
            dist.MultivariateNormal(loc=np.zeros(K.shape[0]), covariance_matrix=K),
            obs=y
        )

Would this treat the conditioning for each i separately? That is, given var, length, and noise, would the MultivariateNormal terms be independent, each just getting different inputs and outputs?

Yes (at least assuming you've computed K correctly).
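To sanity-check the "computed K correctly" part: a valid covariance matrix must be symmetric and positive definite, and positive definiteness can be tested with a Cholesky factorization (JAX's `jnp.linalg.cholesky` returns NaNs rather than raising when the input is not positive definite). A minimal sketch that copies the thread's `rbf_covariance` without the `jit` decorator; `is_valid_covariance` is a hypothetical helper.

```python
import jax.numpy as jnp


def rbf_covariance(var, length, noise, x, xp, jitter=1.0e-6, include_noise=True):
    # copied from the thread; x: (N, D), xp: (M, D) -> diff: (N, M, D)
    diff = jnp.expand_dims(x / length, 1) - jnp.expand_dims(xp / length, 0)
    # axis=2 is the feature dimension D, so Z has shape (N, M)
    Z = var * jnp.exp(-0.5 * jnp.sum(diff**2, axis=2))
    if include_noise:
        return Z + (noise + jitter) * jnp.eye(x.shape[0])
    return Z


def is_valid_covariance(K, atol=1e-6):
    symmetric = bool(jnp.allclose(K, K.T, atol=atol))
    # the Cholesky factor is finite only for positive-definite matrices
    positive_definite = bool(jnp.all(jnp.isfinite(jnp.linalg.cholesky(K))))
    return symmetric and positive_definite


x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
K = rbf_covariance(1.0, 0.5, 0.1, x, x)
print(K.shape, is_valid_covariance(K))  # (50, 50) True
```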

Thank you for all your help! I think I figured out my issue after reading this and the links here: I wasn't understanding how the broadcasting was being done.