Mini-batching with Bayesian GPLVM

I'm trying to define mini-batch logic for Bayesian GPLVM training, but have been unsuccessful so far following the suggestions in this older thread: Pyro Bayesian GPLVM SVI with minibatching

So the suggestion in this thread is to use:

X_minibatch = pyro.sample(…, dist.Normal(x_loc[minibatch_indices], x_scale[minibatch_indices]))
y_minibatch = y[minibatch_indices]
self.base_model.set_data(X_minibatch, y_minibatch)

(…in each epoch before calling svi.step())

However, this causes the learnt X representation to be completely out of whack. Any ideas?

It is almost impossible to use Bayesian GPLVM for any meaningfully large dataset without minibatching - it should be fundamental to SVI.

A quick note on this - we're also having to set:
(self/gplvm).X = X_minibatch
as well, since without this Pyro complains that the dimensions don't match up. Perhaps this line is interfering, given that self.X is not an ordinary tensor.
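
For completeness, the per-step logic we're currently running looks roughly like this (just a sketch of our own loop; get_batch_idx, x_loc, x_scale, gplvm, svi and num_steps are all from our setup):

for step in range(num_steps):
    idx = get_batch_idx(y, batch_size)
    X_minibatch = pyro.sample("X", dist.Normal(x_loc[idx], x_scale[idx]))
    y_minibatch = y[idx]
    gplvm.set_data(X_minibatch, y_minibatch)
    gplvm.X = X_minibatch  # without this, Pyro complains about mismatched dimensions
    svi.step()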

"it should be fundamental to SVI"

I guess so. But we need to define the model properly. Could you provide your code for mini-batch GPLVM? Without it, it is hard to figure out why things go wrong.

About a year ago, GPLVM was implemented mainly to show how simple it is to define such a model using the GP module. In the GPLVM tutorial, we didn't even use that helper class. As mentioned in that forum thread, SVGP already supports minibatching, so I guess all we need is to rewrite GPLVM a bit to support it. I'm not sure, though, because I don't know what mini-batch GPLVM does under the hood (what the data and latent variables are at each step, how the likelihood is scaled, …). Some information about this would be very helpful.

Thanks for writing back. As a first cut, I am trying to do regression with the SVGP class and mini-batching (below).

This is straightforward, as X is fixed and given, so one just needs to slice X to obtain a mini-batch. The problem with GPLVM is that X is learnt - it doesn't seem right to put a pyro.sample statement inside the training loop…

import numpy as np
import torch
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from tqdm import tqdm

N = 5000
X = dist.Uniform(0.0, 5.0).sample(sample_shape=(N,))
X_test = torch.linspace(0.0, 5.0, 100)
y = 0.5 * torch.sin(3 * X) + dist.Normal(0.0, 0.05).sample(sample_shape=(N,))

Xu = torch.arange(20.) / 2  # inducing point locations

# initialize the kernel and model
pyro.clear_param_store()
kernel = gp.kernels.RBF(input_dim=1)
likelihood = gp.likelihoods.Gaussian()
svgp = gp.models.VariationalSparseGP(X, y, kernel, Xu, likelihood)

optimizer = pyro.optim.Adam({"lr": 0.01})
svi = SVI(svgp.model, svgp.guide, optimizer, Trace_ELBO())
#loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

num_steps = 2000
losses = np.zeros(num_steps)
bar = tqdm(range(num_steps))
for i in bar:
    idx = get_batch_idx(y, 5)  # get_batch_idx is defined further down
    X_batch = X[idx]
    y_batch = y[idx]
    svgp.set_data(X_batch, y_batch)
    losses[i] = svi.step()
    bar.set_description(str(int(losses[i])))

@vr308 When you have minibatch idx, if GPLVM is not flexible enough, you can do something like

def model(idx):
    X_minibatch = pyro.sample("X", dist.Normal(0, 1).expand([len(idx)]))  # prior
    y_minibatch = y[idx]
    svgp.set_data(X_minibatch, y_minibatch)
    svgp.model()

def guide(idx):
    x_loc = pyro.param("x_loc", torch.zeros(N))
    x_scale = pyro.param("x_scale", torch.ones(N), constraint=constraints.positive)
    X_minibatch = pyro.sample("X", dist.Normal(x_loc[idx], x_scale[idx]))  # guide
    y_minibatch = y[idx]
    svgp.set_data(X_minibatch, y_minibatch)
    svgp.guide()

I don't know why you observed that x_loc is completely out of whack with that approach. What should we do for mini-batch GPLVM? (My difficulty is understanding what we should do in theory, not the implementation - I think the implementation should be straightforward, and I can help with it.) What's missing in the above model/guide pair? Do we need to scale the likelihood or something?

OK - thanks for the pointer. Below is my implementation of GPLVM with batching, using set_data and the SparseGPR class. When I don't use batching / set_data, I am able to reproduce a reasonable 2d latent construction (this reproduces what is in the original Bayesian GPLVM paper, Titsias & Lawrence, 2010). But with batching it is a different story - no matter how long you train, it isn't able to learn. My best guess is a bug in set_data.
[Figures: "good" result without batching vs. "bad" result with batching]

def get_batch_idx(Y, batch_size):
    N = len(Y)
    valid_indices = np.arange(N)
    batch_indices = np.random.choice(valid_indices, size=batch_size, replace=False)
    return batch_indices

# Loading training and test data

dataset_name = 'oilflow'
test_size = 100
n, d, q, X, Y, lb = load_real_data(dataset_name)
Y = Y @ Y.std(axis=0).diag().inverse()

q = 2

# initialize the kernel and model
pyro.clear_param_store()
kernel = gp.kernels.RBF(input_dim=q, lengthscale=torch.ones(q))
likelihood = gp.likelihoods.Gaussian()

latent_prior_mean = torch.zeros(Y.size(0), q)
X_init = float_tensor(PCA(q).fit_transform(Y))
X = torch.nn.Parameter(X_init) 

Xu = float_tensor(np.random.normal(size=(25, q)))

optimizer = pyro.optim.Adam({"lr": 0.01})
svgp = gp.models.SparseGPRegression(X, Y.T, kernel, Xu, jitter=1e-4)

### training with SVI style syntax
svi = SVI(svgp.model, svgp.guide, optimizer, Trace_ELBO())

svgp.X = pyro.nn.PyroSample(dist.Normal(X_init, 1))
svgp.autoguide('X', dist.Normal)

num_steps = 2000
losses = np.zeros(num_steps)
bar = tqdm(range(num_steps))
for i in bar:
    idx = get_batch_idx(Y, 1000)
    #X_batch = X[idx]
    #y_batch = Y[idx]
    #svgp.set_data(X_batch, y_batch.T)
    losses[i] = svi.step(idx)
    bar.set_description(str(int(losses[i])))
    
X = svgp.X_loc.detach()
plt.figure()
plt.scatter(X[:,0], X[:,1], c=lb)

Could you write pseudo-code for minibatch GPLVM? I don't know if the implementation in my last comment does what is called "minibatch GPLVM". If it is correct in principle, then I also don't know what's wrong with set_data. Does it set wrong values for the GP or something?
If I understand correctly, we need to use pyro.poutine.scale to scale the prior/guide of X by num_data/batch_size, and use the VSGP class with num_data set to the full data size (so that we have an unbiased estimate of the full-data likelihood). SGPR does not support minibatching. Those are suggestions based on my understanding of minibatch GPLVM. Does that make sense to you? (Or are those steps unnecessary?) Again, my main difficulty is understanding what you want. If you understand well how minibatch GPLVM works, then pseudo-code would be very helpful, and I can help you implement it in Pyro. If after that the code still does not work, then probably there is a bug somewhere, or minibatch GPLVM just does not work… At least your last code does not seem correct to me: it is missing the scales and uses a non-minibatch GP model, so there is plenty of room here for fixes and improvements.
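
To be concrete about the "unbiased estimate" point: scaling a minibatch sum by num_data/batch_size gives an unbiased estimate of the corresponding full-data sum. A quick toy check (the values here are made up, just to illustrate):

import torch

N, B = 1000, 100
log_p = torch.randn(N)  # pretend per-point log-likelihoods
full_sum = log_p.sum()

# average the scaled minibatch estimator over many random batches
estimates = torch.stack([
    (N / B) * log_p[torch.randperm(N)[:B]].sum()
    for _ in range(5000)
])
print(full_sum, estimates.mean())  # the two numbers should be close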

Let me get you some pseudo-code shortly. Thanks.

The GPLVM has two types of latent variables: global ones (u) that are used in the sparse GP prior, and local ones (x_i) that are the latent inputs for each data point. The associated log_probs need to be scaled differently in the loss, so if something is going wrong I imagine it might be there.
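
Concretely, for a minibatch B drawn from the N data points, the objective should look roughly like this (a sketch of the standard doubly-stochastic formulation; the notation is ours, not Pyro's):

$$
\mathcal{L}_{\text{batch}} = \frac{N}{B}\sum_{i\in\mathcal{B}} \mathbb{E}_{q}\!\left[\log p(y_i \mid x_i, u)\right] \;-\; \frac{N}{B}\sum_{i\in\mathcal{B}} \mathrm{KL}\!\left(q(x_i)\,\|\,p(x_i)\right) \;-\; \mathrm{KL}\!\left(q(u)\,\|\,p(u)\right)
$$

The likelihood term and the local x_i terms pick up the N/B factor, while the global KL(q(u) || p(u)) term appears once, unscaled.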

Yeah, I think we need that to make the inference work. @vr308 Could you try adding a scale handler for X_minibatch in both the model and the guide

with pyro.poutine.scale(scale=X.size(0) / batch_size):
    X_minibatch = pyro.sample("X", ...)

and use svgp = VariationalSparseGP(..., num_data=X.size(0))? You might also want to use TraceMeanField_ELBO, as in the DKL example, if the training process is highly unstable.

So I rewrote it a bit to resemble the model/guide syntax:

def model(idx):
    with pyro.poutine.scale(scale=X.size(0) / len(idx)):
        X_minibatch = pyro.sample("X", dist.Normal(torch.zeros(len(idx),q), 1))  # prior
        y_minibatch = Y[idx]
        svgp.set_data(X_minibatch, y_minibatch.T)
        svgp.model()

def guide(idx):  
    x_loc = pyro.param("x_loc", torch.zeros(1000,q))
    x_scale = pyro.param("x_scale", torch.ones(1000,q), constraint=constraints.positive)
    with pyro.poutine.scale(scale=X.size(0) / len(idx)):
        X_minibatch = pyro.sample("X", dist.Normal(x_loc[idx], x_scale[idx]))  # guide
        y_minibatch = Y[idx]
        svgp.set_data(X_minibatch, y_minibatch.T)
        svgp.guide()  

Let me know if you see anything suspicious here…

dataset_name = 'oilflow'
test_size = 100
n, d, q, X, Y, lb = load_real_data(dataset_name)
Y = Y @ Y.std(axis=0).diag().inverse()

q = 2

# initialize the kernel and model
pyro.clear_param_store()
kernel = gp.kernels.RBF(input_dim=q, lengthscale=torch.ones(q))
likelihood = gp.likelihoods.Gaussian()

latent_prior_mean = torch.zeros(Y.size(0), q)
X_init = float_tensor(PCA(q).fit_transform(Y))
X = torch.nn.Parameter(X_init) 

Xu = float_tensor(np.random.normal(size=(25, q)))

optimizer = pyro.optim.Adam({"lr": 0.01})
svgp = gp.models.VariationalSparseGP(X, Y.T, kernel, Xu, likelihood, num_data=X.size(0))

### training with SVI style syntax
svi = SVI(model, guide, optimizer, Trace_ELBO())

svgp.X = pyro.nn.PyroSample(dist.Normal(X_init, 1))
svgp.autoguide('X', dist.Normal)

num_steps = 5000
losses = np.zeros(num_steps)
bar = tqdm(range(num_steps))
for i in bar:
    idx = get_batch_idx(Y, 100)
    losses[i] = svi.step(idx)
    bar.set_description(str(int(losses[i])))

While it runs without errors, I believe it doesn't recover the batch-less version when I set the idx size to 1000 (i.e. using the full dataset at each step, but via set_data).

Those lines (svgp.X = pyro.nn.PyroSample(...) and svgp.autoguide('X', dist.Normal)) are not needed, because X is already handled in your model/guide pair (unless you want to run SVI over svgp.model/svgp.guide as in the GPLVM example). Also, you should only scale X, i.e. move svgp.model()/svgp.guide() out of the scale handler, and make sure that 1000 is the size of your full dataset:

x_loc = pyro.param("x_loc", torch.zeros(1000,q))
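
Putting that together, the model/guide pair would look roughly like this (a sketch adapted from your code above; only the pyro.sample for X sits inside the scale handler):

def model(idx):
    # only the prior over the local latents X is scaled
    with pyro.poutine.scale(scale=X.size(0) / len(idx)):
        X_minibatch = pyro.sample("X", dist.Normal(torch.zeros(len(idx), q), 1.0))
    svgp.set_data(X_minibatch, Y[idx].T)
    svgp.model()  # likelihood is rescaled internally via num_data=X.size(0)

def guide(idx):
    x_loc = pyro.param("x_loc", torch.zeros(X.size(0), q))
    x_scale = pyro.param("x_scale", torch.ones(X.size(0), q), constraint=constraints.positive)
    # only the variational factor over the local latents X is scaled
    with pyro.poutine.scale(scale=X.size(0) / len(idx)):
        X_minibatch = pyro.sample("X", dist.Normal(x_loc[idx], x_scale[idx]))
    svgp.set_data(X_minibatch, Y[idx].T)
    svgp.guide()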

Great - I believe this is working now. Yes, 1000 is the size of my full dataset, so it can be replaced with X.size(0).

# Loading training and test data

dataset_name = 'oilflow'
test_size = 100
n, d, q, X, Y, lb = load_real_data(dataset_name)
Y = Y @ Y.std(axis=0).diag().inverse()

q = 2

# initialize the kernel and model
pyro.clear_param_store()
kernel = gp.kernels.RBF(input_dim=q, lengthscale=torch.ones(q))
likelihood = gp.likelihoods.Gaussian()

latent_prior_mean = torch.zeros(Y.size(0), q)
X_init = float_tensor(PCA(q).fit_transform(Y))
X = X_init

Xu = float_tensor(np.random.normal(size=(25, q)))  # inducing points, as before

optimizer = pyro.optim.Adam({"lr": 0.01})
svgp = gp.models.VariationalSparseGP(X, Y.T, kernel, Xu, likelihood, num_data=X_init.size(0))

### training with SVI style syntax
svi = SVI(model, guide, optimizer, Trace_ELBO())

num_steps = 5000
losses = np.zeros(num_steps)
bar = tqdm(range(num_steps))
for i in bar:
    idx = get_batch_idx(Y, 100)
    losses[i] = svi.step(idx)
    bar.set_description(str(int(losses[i])))

After running this I extract the learnt latent X’s with:

X = svgp.X.detach()  # <- the size of X here is that of the mini-batch
plt.figure()
plt.scatter(X[:,0], X[:,1], c=lb)

How can I retrieve / persist the full X after training? Thanks.

All the information about X is available in the params x_loc and x_scale. You can use pyro.param("x_loc") to get what you want.
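
For example (a sketch; lb is the label vector from your load_real_data call):

X_full = pyro.param("x_loc").detach()  # variational means of all N latent points
plt.figure()
plt.scatter(X_full[:, 0], X_full[:, 1], c=lb)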

The probabilistic model is: take a batch of x_loc, x_scale, sample x from Normal(x_loc, x_scale), and feed that sample to the GP together with the corresponding batch of y. The code does exactly that. Also, if you want to optimize performance, use TraceMeanField_ELBO - it is recommended for GP models.
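
i.e. just swap the ELBO when constructing SVI:

from pyro.infer import SVI, TraceMeanField_ELBO

svi = SVI(model, guide, optimizer, TraceMeanField_ELBO())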