Hello pyro people
I’m trying to implement an empirical Bayes prior via a latent Gaussian process, and struggling a bit with how to structure it…
I have code that works without the empirical Bayes part (via pyro.sample calls). I’ve now been able to add a GP and get it to work, or at least to run. However, I can’t access any of the values learned for the GP mean and variance, so I can’t see whether it’s doing anything sensible. I’ve posted my current model at the bottom.
The main thing I want to do is to access the mean and variance values for the GP prior at the various values of X.
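(One idea I’ve had, but haven’t verified, is to record them with pyro.deterministic inside the model so that they show up in posterior samples; here phi_gp is the GP defined in the model below, and the site names are ones I made up:)

# inside the model, right after querying the GP:
phi_loc, phi_var = phi_gp.model()
pyro.deterministic("phi_loc", phi_loc)  # record the GP prior mean
pyro.deterministic("phi_var", phi_var)  # record the GP prior variance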
I suspect that what I need to do is convert the whole thing to a slightly more sophisticated structure, where I define the model as a class with model and guide methods. This discussion looks relevant, but it’s basically about “linear NN + GP”, whereas what I would like to do is “GP + standard pyro.sample calls”.
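Roughly, the class structure I have in mind is something like this (just a skeleton; the class name and method bodies are placeholders):

class GammaPoissonGPModel:
    def __init__(self, logcpms):
        self.logcpms = logcpms

    def model(self, y_counts, x_fix):
        # GP prior here, followed by the standard pyro.sample calls
        ...

    def guide(self, y_counts, x_fix):
        # hand-written guide covering both the GP and non-GP sites
        ...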
Things I’m confused about:
- How to define a guide for the GP and also for the non-GP parts (e.g. I can’t see how to use AutoDiagonalNormal for just the non-GP parts); I’ve sketched a guess just after this list.
- How to define the predictive function for such a model (my best guess is at the very bottom of the post).
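For the first point, my guess from the autoguide docs is that AutoGuideList plus poutine.block might be the way to split the guide between GP and non-GP sites, though I haven’t checked that this plays nicely with the GP module:

from pyro import poutine
from pyro.infer.autoguide import AutoDiagonalNormal, AutoGuideList, AutoNormal

gp_sites = ["phi_sd", "phi_len", "log_inv_phi"]
guide = AutoGuideList(mdl_fixed_gammapoisson_gp)
# diagonal-normal guide over everything except the GP-related sites
guide.append(AutoDiagonalNormal(
    poutine.block(mdl_fixed_gammapoisson_gp, hide=gp_sites)))
# separate guide restricted to the GP-related sites
guide.append(AutoNormal(
    poutine.block(mdl_fixed_gammapoisson_gp, expose=gp_sites)))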
A couple of further comments, in case it makes a difference to choice of structure:
- At some point I’d like to use a SparseGPRegression model (currently I’m running with 100 features, but in future I would like to do more); a guess at what that swap looks like is sketched after this list.
- I’d also like to reimplement this in NumPyro for the speed benefits.
- I don’t care which implementation of GP I use; I’d be happy with anything that works and is reasonably fast (e.g. tinygp?).
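For the sparse case, I’m assuming the swap inside the model would look roughly like this, with Xu being inducing inputs I’d choose myself (the every-10th-point choice is just illustrative):

Xu = logcpms[::10].clone()  # hypothetical inducing points
phi_gp = gp.models.SparseGPRegression(
    logcpms, None,
    gp.kernels.RBF(1),
    Xu=Xu,
    noise=torch.tensor(1e-3),
    mean_function=lambda x: x,
)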
I hope this all makes sense!
Thanks,
Will
Model definition
import torch
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist


def mdl_fixed_gammapoisson_gp(y_counts, x_fix, cols_fix, mean_int, logcpms):
# some setup
pyro.enable_validation(True)
min_value = torch.finfo(torch.float).eps
max_value = torch.finfo(torch.float).max
n_obs = y_counts.shape[0]
n_gs = y_counts.shape[1]
n_fix = x_fix.shape[1]
    # sample GP hyperparameters (kernel variance and lengthscale)
phi_sd = pyro.sample("phi_sd", dist.HalfCauchy(1.0))
phi_len = pyro.sample("phi_len", dist.LogNormal(0.0, 1.0))
    # define the GP over log inverse overdispersion as a function of log CPM
    phi_gp = gp.models.GPRegression(
        logcpms, None,
        gp.kernels.RBF(1),
        noise=torch.tensor(1e-3),
        mean_function=lambda x: x,
    )
    # tie the kernel hyperparameters to the values sampled above
    phi_gp.kernel.variance = phi_sd
    phi_gp.kernel.lengthscale = phi_len
    # sample overdispersion parameters from the GP prior
    with pyro.plate("log_inv_phi_genes", n_gs):
        phi_loc, phi_var = phi_gp.model()
        # Normal takes a scale, so use the square root of the GP variance
        log_inv_phi = pyro.sample("log_inv_phi", dist.Normal(phi_loc, phi_var.sqrt()))
phi = pyro.deterministic("phi",
torch.exp(-log_inv_phi).clamp(min=min_value, max=max_value))
# sample from intercepts
with pyro.plate("intcpt_genes", n_gs):
prior_int = dist.Normal(mean_int, 2.0 * torch.ones(n_gs))
# prior_int = dist.Uniform(mean_int - 0.1, mean_int + 0.1)
beta_intercept = pyro.sample("beta_intercept", prior_int)
# sample from sigmas
with pyro.plate("sigma_fix_genes", n_gs):
dist_sd_fix = dist.HalfNormal(scale=5.0)
sd_fix = pyro.sample("sigma_fix", dist_sd_fix).clamp(min=min_value)
# sample from fixed values
with pyro.plate("beta_fix_genes", n_gs):
prior_fix = dist.Normal(0.0, sd_fix)
with pyro.plate("beta_fix_fix", n_fix):
beta_fix = pyro.sample("beta_fix", prior_fix)
# check that outputs are ok
assert beta_fix.shape[0] == n_fix
assert beta_fix.shape[1] == n_gs
# add these all together, exponentiate
mu_mat = beta_intercept + torch.mm(x_fix, beta_fix)
means = torch.exp(mu_mat).clamp(min=min_value, max=max_value)
# transform to gamma poisson parameters
alpha = phi
beta = phi / means
    # Gamma-Poisson likelihood, conditioned on the observed counts
    with pyro.plate("data_genes", n_gs):
        with pyro.plate("data_obs", n_obs):
            obs_dist = dist.GammaPoisson(alpha, beta)
            obs = pyro.sample("obs", obs_dist, obs=y_counts)
Then I set up a guide like this:
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import ClippedAdam

guide = AutoDiagonalNormal(mdl_fixed_gammapoisson_gp)
my_svi = SVI(
    model=mdl_fixed_gammapoisson_gp,
    guide=guide,
    optim=ClippedAdam({"lr": 0.01, "clip_norm": 1.0}),
    loss=Trace_ELBO(),
)
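And once training has run, my best guess for the predictive step is something like this, assuming the GP mean and variance have been recorded as deterministic sites as in the sketch near the top (the site names phi_loc and phi_var are my own):

from pyro.infer import Predictive

predictive = Predictive(
    mdl_fixed_gammapoisson_gp,
    guide=guide,
    num_samples=200,
    return_sites=["phi_loc", "phi_var", "log_inv_phi"],
)
samples = predictive(y_counts, x_fix, cols_fix, mean_int, logcpms)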