Message passing between Pyro models

Hi,

I have a Pyro model which is trained with SVI. It has a matrix as a latent variable. I now have multiple datasets and train one model instance on each of them. I expect the matrix variable to be the same for every dataset, up to dropouts of single rows. Because of an additional degree of freedom in the model, the learned rows could also be permuted. I wonder whether it is possible to do some sort of message passing between the models at training time, to obtain one global matrix parameter. I can’t create one big model for all datasets, because it would use too much memory. Does anyone have an idea?

can you fit all the data in memory at once?

yes, I think that should be fine. Most of the memory usage comes from the model.

well you didn’t share any details about your model but the natural thing to do might be to create a hierarchical model of some kind. during training you only subsample one dataset at a time

Sorry, I should have described the model in my first post. It is similar to Gaussian process factor analysis. The factor loadings matrix should be the same for all datasets, but I want to get different sets of latent factors for the different datasets. The latent factors (batched GPs in GPyTorch) are memory hungry, so I can’t train a model with a latent space of shape [number of latent factors, number of datasets]. If I understand correctly, you propose to subsample the datasets dimension (?). I’m not sure what this would look like in this case, but maybe this question is more about GPyTorch than about Pyro…?
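Written out, the structure described above is roughly the following. This is only a sketch: the symbols $W$, $z^{(d)}_k$, $k^{(d)}_k$ and the additive noise term $\varepsilon^{(d)}$ are assumed notation, since the likelihood is not spelled out here.

```latex
y^{(d)}(x) = W \, z^{(d)}(x) + \varepsilon^{(d)}(x), \qquad
z^{(d)}_k \sim \mathcal{GP}\bigl(0,\; k^{(d)}_k\bigr), \qquad
d = 1, \dots, D, \quad k = 1, \dots, K
```

Here $W$ (number of features by $K$) is shared across all $D$ datasets, while each dataset $d$ has its own $K$ latent GPs.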

can you give more details? are you using sparse GPs with inducing points?

yes, sure! This is a minimal and incomplete version of the current model, as it would look if I could fit everything in memory.

Maybe I should also add that x could be different points for different datasets.

import torch
import gpytorch
import pyro

class GPFA(gpytorch.models.ApproximateGP):
    def __init__(self):
        self.n_features = 50
        self.n_datasets = 5
        self.n_latent_factors = 3

        inducing_points = torch.rand([self.n_datasets, self.n_latent_factors, 100, 2])

        var_dist = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing_points=inducing_points.shape[-2],
            batch_shape=torch.Size([self.n_datasets, self.n_latent_factors])
        )

        variational_strategy = gpytorch.variational.VariationalStrategy(
            model=self,
            inducing_points=inducing_points,
            variational_distribution=var_dist,
        )

        super().__init__(variational_strategy=variational_strategy)

        self.mean_module = gpytorch.means.ZeroMean(
            batch_shape=torch.Size([self.n_datasets, self.n_latent_factors])
        )

        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                batch_shape=torch.Size([self.n_datasets, self.n_latent_factors])
            ),
            batch_shape=torch.Size([self.n_datasets, self.n_latent_factors]),
        )

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

    def model(self, x, y):
        pyro.module('gp', self)

        sample_plate = pyro.plate('sample_plate', dim=-1)
        latent_factors_plate = pyro.plate('latent_factors_plate', dim=-2)
        datasets_plate = pyro.plate('datasets_plate', dim=-3)
        features_plate = pyro.plate('features_plate', dim=-4)

        # factor loadings
        with features_plate, latent_factors_plate:
            w = pyro.sample('w', pyro.distributions.Normal(0, 1))

        # latent factors
        with datasets_plate, latent_factors_plate, sample_plate:
            z = pyro.sample('z', self.pyro_model(x))

        # ...

    def guide(self, x, y):
        sample_plate = pyro.plate('sample_plate', dim=-1)
        latent_factors_plate = pyro.plate('latent_factors_plate', dim=-2)
        datasets_plate = pyro.plate('datasets_plate', dim=-3)
        features_plate = pyro.plate('features_plate', dim=-4)

        # variational mean of the factor loadings
        w_loc = pyro.param(
            name='w_loc',
            init_tensor=torch.zeros([self.n_features, 1, self.n_latent_factors, 1]),
        )

        # variational standard deviation of the factor loadings (must be positive)
        w_scale = pyro.param(
            name='w_scale',
            init_tensor=torch.ones([self.n_features, 1, self.n_latent_factors, 1]),
            constraint=pyro.distributions.constraints.positive,
        )

        with features_plate, latent_factors_plate:
            w = pyro.sample('w', pyro.distributions.Normal(w_loc, w_scale))

        with datasets_plate, latent_factors_plate, sample_plate:
            z = pyro.sample('z', self.pyro_guide(x))

        # ...