Yes, sure! This is a minimal and incomplete version of the current model, as it would look if I could fit everything in memory.
I should also add that `x` could consist of different points for different datasets.
import torch
import gpytorch
import pyro
class GPFA(gpytorch.models.ApproximateGP):
    """Gaussian-process factor analysis sketch.

    Latent GP factors (one GP per (dataset, latent factor) batch entry) are
    combined with factor loadings ``w`` to model observed features.

    Plate/batch layout used throughout (rightmost dims):
        features = -4, datasets = -3, latent factors = -2, samples = -1.

    NOTE(review): this is explicitly an incomplete sketch — the likelihood of
    ``y`` is omitted (the ``# ...`` placeholders below).
    """

    def __init__(self):
        self.n_features = 50
        self.n_datasets = 5
        self.n_latent_factors = 3
        # One independent set of 100 inducing points in a 2-D input space for
        # each (dataset, latent factor) pair.
        inducing_points = torch.rand(
            [self.n_datasets, self.n_latent_factors, 100, 2]
        )
        batch_shape = torch.Size([self.n_datasets, self.n_latent_factors])
        var_dist = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=inducing_points.shape[-2],
            batch_shape=batch_shape,
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            model=self,
            inducing_points=inducing_points,
            variational_distribution=var_dist,
        )
        super().__init__(variational_strategy=variational_strategy)
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=batch_shape),
            batch_shape=batch_shape,
        )

    def forward(self, x):
        """Return the batched GP prior at inputs ``x`` as a MultivariateNormal."""
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

    def model(self, x, y):
        """Pyro model: standard-normal prior on loadings ``w`` and a GP prior
        on the latent factors ``z`` (via GPyTorch's ``pyro_model``).

        ``y`` is unused here — the observation likelihood is the omitted part.
        """
        pyro.module('gp', self)
        sample_plate = pyro.plate('sample_plate', dim=-1)
        latent_factors_plate = pyro.plate('latent_factors_plate', dim=-2)
        datasets_plate = pyro.plate('datasets_plate', dim=-3)
        features_plate = pyro.plate('features_plate', dim=-4)
        # Factor loadings: one scalar per (feature, latent factor).
        with features_plate, latent_factors_plate:
            w = pyro.sample('w', pyro.distributions.Normal(0, 1))
        # Latent factors: GP prior per (dataset, latent factor), over samples.
        with datasets_plate, latent_factors_plate, sample_plate:
            z = pyro.sample('z', self.pyro_model(x))
        # ...

    def guide(self, x, y):
        """Pyro guide: mean-field Normal posterior over ``w`` with learnable
        location/scale, and GPyTorch's variational posterior over ``z``.
        """
        sample_plate = pyro.plate('sample_plate', dim=-1)
        latent_factors_plate = pyro.plate('latent_factors_plate', dim=-2)
        datasets_plate = pyro.plate('datasets_plate', dim=-3)
        features_plate = pyro.plate('features_plate', dim=-4)
        w_loc = pyro.param(
            name='w_loc',
            init_tensor=torch.zeros([self.n_features, 1, self.n_latent_factors, 1]),
        )
        # BUG FIX: the original registered this second parameter under the
        # duplicate name 'w_loc' (and rebound the same variable), so the scale
        # parameter never existed. Given the ones-init and positive constraint
        # it is clearly the scale.
        w_scale = pyro.param(
            name='w_scale',
            init_tensor=torch.ones([self.n_features, 1, self.n_latent_factors, 1]),
            constraint=pyro.distributions.constraints.positive,
        )
        with features_plate, latent_factors_plate:
            # BUG FIX: the guide must sample from the learnable variational
            # distribution, not the fixed prior Normal(0, 1); otherwise
            # w_loc/w_scale are dead parameters that are never trained.
            w = pyro.sample('w', pyro.distributions.Normal(w_loc, w_scale))
        with datasets_plate, latent_factors_plate, sample_plate:
            z = pyro.sample('z', self.pyro_guide(x))
        # ...