Combining GP and pyro sample

Hello pyro people

I’m trying to implement an empirical Bayes prior via a latent Gaussian process, and struggling a bit with how to structure it…

I have code that works without the empirical Bayes part (via pyro.sample calls). I’ve now been able to add a GP and get it to work, or at least to run. However, I can’t access any of the values learned for the GP mean and variance, and so I can’t see whether it’s doing anything sensible. I’ve posted my current model at the bottom.

The main thing I want to do is to access the mean and variance values for the GP prior at the various values of X.

I suspect that what I need to do is convert the whole thing to a slightly more sophisticated structure, where I define the model as a class, with model and guide methods. This discussion looks relevant, but the discussion there is basically about “linear NN + GP”, whereas what I would like to do is “GP + standard pyro sample calls”.

Things I’m confused about:

  • How to define a guide for the GP and also for the non-GP parts (e.g. I can’t see how to use AutoDiagonalNormal for the non-GP parts).
  • How to define the predictive function for such a model.

A couple of further comments, in case it makes a difference to choice of structure:

  • At some point I’d like to use a SparseGPRegression model (currently I’m running for 100 features, but in future I would like to do more).
  • I’d also like to reimplement in NumPyro for speed benefits.
  • I don’t care which implementation of GP I use :slight_smile: I’d be happy with anything that works and is reasonably fast (e.g. tinygp?).

I hope this all makes sense!

Thanks,
Will

Model definition

def mdl_fixed_gammapoisson_gp(y_counts, x_fix, cols_fix, mean_int, logcpms):
  # some setup
  pyro.enable_validation(True)
  min_value = torch.finfo(torch.float).eps
  max_value = torch.finfo(torch.float).max
  n_obs     = y_counts.shape[0]
  n_gs      = y_counts.shape[1]
  n_fix     = x_fix.shape[1]

  # define prior for inv phi
  phi_sd    = pyro.sample("phi_sd", dist.HalfCauchy(1.0))
  phi_len   = pyro.sample("phi_len", dist.LogNormal(0.0, 1.0))

  # define prior for inv phi
  phi_gp    = gp.models.GPRegression(
    logcpms, None, 
    gp.kernels.RBF(1), 
    noise = torch.tensor(1e-3),
    mean_function = lambda x: x
  )
  phi_gp.kernel.variance    = phi_sd
  phi_gp.kernel.lengthscale = phi_len

  # sample overdispersion parameters
  with pyro.plate("log_inv_phi_genes", n_gs):
    phi_loc, phi_var  = phi_gp.model()
    log_inv_phi       = pyro.sample("log_inv_phi", dist.Normal(phi_loc, phi_var))
    phi               = pyro.deterministic("phi", 
      torch.exp(-log_inv_phi).clamp(min=min_value, max=max_value))

  # sample from intercepts
  with pyro.plate("intcpt_genes", n_gs):
    prior_int       = dist.Normal(mean_int, 2.0 * torch.ones(n_gs))
    # prior_int       = dist.Uniform(mean_int - 0.1, mean_int + 0.1)
    beta_intercept  = pyro.sample("beta_intercept", prior_int)

  # sample from sigmas
  with pyro.plate("sigma_fix_genes", n_gs):
    dist_sd_fix = dist.HalfNormal(scale=5.0)
    sd_fix      = pyro.sample("sigma_fix", dist_sd_fix).clamp(min=min_value)

  # sample from fixed values
  with pyro.plate("beta_fix_genes", n_gs):
    prior_fix = dist.Normal(0.0, sd_fix)
    with pyro.plate("beta_fix_fix", n_fix):
      beta_fix  = pyro.sample("beta_fix", prior_fix)

  # check that outputs are ok
  assert beta_fix.shape[0] == n_fix
  assert beta_fix.shape[1] == n_gs

  # add these all together, exponentiate
  mu_mat    = beta_intercept + torch.mm(x_fix, beta_fix)
  means     = torch.exp(mu_mat).clamp(min=min_value, max=max_value)

  # transform to gamma poisson parameters
  alpha     = phi
  beta      = phi / means

  # draw samples from Poisson and condition them on our data
  with pyro.plate("data_genes", n_gs):
    with pyro.plate("data_obs", y_counts.shape[0]):
      post_dist = dist.GammaPoisson(alpha, beta)
      obs       = pyro.sample("obs", post_dist, obs=y_counts)

Then I set up a guide like this:

  guide   = AutoDiagonalNormal(mdl_fixed_gammapoisson_gp)
  my_svi  = SVI(
    model   = mdl_fixed_gammapoisson_gp, 
    guide   = guide,
    optim   = ClippedAdam({"lr": 0.01, 'clip_norm': 1.0}),
    loss    = Trace_ELBO()
    )