# Combining GP and pyro sample

Hello Pyro people

I’m trying to implement an empirical Bayes prior via a latent Gaussian process, and struggling a bit with how to structure it…

I have code that works without the empirical Bayes part (via `pyro.sample` calls). I’ve now been able to add a GP and get it to work, or at least to run. However, I can’t access any of the values learned for the GP mean and variance, and so I can’t see whether it’s doing anything sensible. I’ve posted my current model at the bottom.

The main thing I want to do is to access the mean and variance values for the GP prior at the various values of X.

I suspect that what I need to do is convert the whole thing to a slightly more sophisticated structure, where I define the model as a class with `model` and `guide` methods. This discussion looks relevant, but it is basically about “linear NN + GP”, whereas what I would like is “GP + standard pyro sample calls”. In particular, I’m not sure:

• How to define a guide for the GP and also for the non-GP parts (e.g. I can’t see how to use `AutoDiagonalNormal` for the non-GP parts).
• How to define the `predictive` function for such a model.

A couple of further comments, in case it makes a difference to the choice of structure:

• At some point I’d like to use a `SparseGPRegression` model (currently I’m running for 100 features, but in future I would like to do more).
• I’d also like to reimplement in `NumPyro` for speed benefits.
• I don’t care which implementation of GP I use; I’d be happy with anything that works and is reasonably fast (e.g. `tinygp`?).

I hope this all makes sense!

Thanks,
Will

### Model definition

```python
import torch
import pyro
import pyro.distributions as dist
import pyro.contrib.gp as gp
from pyro.nn import PyroSample


def mdl_fixed_gammapoisson_gp(y_counts, x_fix, cols_fix, mean_int, logcpms):
    # some setup (cols_fix is not used below)
    pyro.enable_validation(True)
    min_value = torch.finfo(torch.float).eps
    max_value = torch.finfo(torch.float).max
    n_obs     = y_counts.shape[0]
    n_gs      = y_counts.shape[1]
    n_fix     = x_fix.shape[1]

    # define GP prior for log inv phi, with the identity as mean function
    phi_gp    = gp.models.GPRegression(
        logcpms, None,
        gp.kernels.RBF(1),
        noise = torch.tensor(1e-3),
        mean_function = lambda x: x
    )
    # hyperpriors on the kernel: PyroSample makes them sample sites
    # ("kernel.variance", "kernel.lengthscale"). Assigning a sampled tensor
    # directly only copies a detached value into the kernel's PyroParam,
    # so no gradient would reach the hyperparameters.
    phi_gp.kernel.variance    = PyroSample(dist.HalfCauchy(1.0))
    phi_gp.kernel.lengthscale = PyroSample(dist.LogNormal(0.0, 1.0))

    # with y=None, model() returns the GP mean and marginal variance at logcpms.
    # Call it outside the plate so the hyperparameters are sampled only once.
    # Only the marginals are used, so the GP does not correlate genes here.
    phi_loc, phi_var = phi_gp.model()

    # sample overdispersion parameters (Normal takes a standard deviation)
    with pyro.plate("log_inv_phi_genes", n_gs):
        log_inv_phi = pyro.sample("log_inv_phi", dist.Normal(phi_loc, phi_var.sqrt()))
        phi         = pyro.deterministic("phi",
            torch.exp(-log_inv_phi).clamp(min=min_value, max=max_value))

    # sample from intercepts
    with pyro.plate("intcpt_genes", n_gs):
        prior_int       = dist.Normal(mean_int, 2.0 * torch.ones(n_gs))
        # prior_int       = dist.Uniform(mean_int - 0.1, mean_int + 0.1)
        beta_intercept  = pyro.sample("beta_intercept", prior_int)

    # sample from sigmas
    with pyro.plate("sigma_fix_genes", n_gs):
        dist_sd_fix = dist.HalfNormal(scale=5.0)
        sd_fix      = pyro.sample("sigma_fix", dist_sd_fix).clamp(min=min_value)

    # sample from fixed values: beta_fix has shape (n_fix, n_gs)
    with pyro.plate("beta_fix_genes", n_gs):
        prior_fix = dist.Normal(0.0, sd_fix)
        with pyro.plate("beta_fix_fix", n_fix):
            beta_fix  = pyro.sample("beta_fix", prior_fix)

    # check that outputs are ok
    assert beta_fix.shape[0] == n_fix
    assert beta_fix.shape[1] == n_gs

    # add these all together, exponentiate: means has shape (n_obs, n_gs)
    mu_mat    = beta_intercept + torch.mm(x_fix, beta_fix)
    means     = torch.exp(mu_mat).clamp(min=min_value, max=max_value)

    # transform to gamma poisson parameters (mean = alpha / beta = means)
    alpha     = phi
    beta      = phi / means

    # Gamma-Poisson likelihood, conditioned on our data
    with pyro.plate("data_genes", n_gs):
        with pyro.plate("data_obs", n_obs):
            post_dist = dist.GammaPoisson(alpha, beta)
            obs       = pyro.sample("obs", post_dist, obs=y_counts)
```

Then I set up a guide like this:

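```python
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import ClippedAdam

guide   = AutoDiagonalNormal(mdl_fixed_gammapoisson_gp)
my_svi  = SVI(
    model   = mdl_fixed_gammapoisson_gp,
    guide   = guide,
    optim   = ClippedAdam({"lr": 0.01, "clip_norm": 1.0}),
    loss    = Trace_ELBO()
)
```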