I’m at a loss here. I’ve written a sizeable amount of code with the Pyro GPLVM at its core; below is a minimal(ish) distillation that reproduces the problem. I’ve confirmed the behaviour on two machines.
The issue is this: I’m trying to fit a GPLVM with a normal variational distribution. If I initialise the mean parameter with sklearn’s PCA, the parameter doesn’t change after training. If I add some normal noise to that initialisation with the in-place `+=` operator, the parameter still doesn’t change. It obviously isn’t that the parameter is initialised at an optimum, since the added noise would have knocked it off one. The parameter does change if it’s initialised with `np.random.normal`, and, strangely, even the PCA initialisation trains fine if I add the noise out-of-place with `mu = mu + error`.

What’s going on? Am I missing something very obvious?
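For reference, the only difference between the two noise variants is in-place versus out-of-place assignment: `+=` mutates the array that PCA returned, while `mu = mu + error` rebinds `mu` to a brand-new array. A minimal illustration (plain NumPy, nothing Pyro-specific):

```python
import numpy as np

a = np.zeros(3)
alias = a
a += 1.0       # in-place: the existing array is mutated, so `alias` changes too
print(alias)   # [1. 1. 1.]

a = np.zeros(3)
alias = a
a = a + 1.0    # out-of-place: `a` is rebound to a new array, `alias` is untouched
print(alias)   # [0. 0. 0.]
```

Both variants leave `mu` holding the same values, so I’d expect them to behave identically once the result is copied into a `torch.nn.Parameter`.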
```python
from uuid import uuid4

import numpy as np
import pyro
import pyro.contrib.gp as gp
import pyro.distributions as dist
import torch
from sklearn.decomposition import PCA


def float_tensor(X):
    return torch.tensor(X).float()


class VarParams(torch.nn.Module):
    """Diagonal normal variational distribution pushed through a planar flow."""

    def __init__(self, Y, q):
        super().__init__()
        self.flows = [dist.transforms.Planar(q)]

        # ------------------> PROBLEM
        mu = PCA(q).fit_transform(Y)                            # DOESN'T WORK
        # mu = mu + np.random.normal(scale=100, size=Y.shape)   # WORKS
        # mu += np.random.normal(scale=100, size=Y.shape)       # DOESN'T
        # mu = np.random.normal(scale=100, size=Y.shape)        # WORKS
        # <---------------------------

        self.mu = torch.nn.Parameter(float_tensor(mu))
        self._log_sigma = torch.nn.Parameter(torch.zeros(self.mu.shape).float())
        self.update_parameters()

    def update_parameters(self):
        # Rebuild the distributions so they refer to the current parameter values.
        self.sigma = self._log_sigma.exp()
        self.base_dist = dist.Normal(self.mu, self.sigma)
        self.flow_dist = dist.TransformedDistribution(self.base_dist, self.flows)


n = 1000
d = q = 2
Y = float_tensor(np.random.normal(size=(n, d)))

gp_module = gp.models.SparseGPRegression(
    X=float_tensor(PCA(q).fit_transform(Y)),
    y=Y.T,
    kernel=gp.kernels.Matern52(q, lengthscale=torch.ones(q)),
    Xu=float_tensor(np.random.normal(size=(25, q))),
    jitter=1e-4,
)
gplvm = gp.models.GPLVM(gp_module)
var_params = VarParams(Y, q)

bef = var_params.mu.detach().clone()


def guide():
    # Fresh unique prefix each call; registers the flow and mean/scale
    # parameters with Pyro.
    name = str(uuid4())
    pyro.module(name + 'flow', var_params.flows[0])
    pyro.module(name + 'mu_sig', var_params)
    return pyro.sample('X', var_params.flow_dist)


svi = pyro.infer.SVI(
    model=gplvm.model,
    guide=guide,
    optim=pyro.optim.Adam(dict(lr=0.01)),
    loss=pyro.infer.Trace_ELBO(),
)

for step in range(10):
    var_params.update_parameters()
    svi.step()

aft = var_params.mu.detach().clone()
print(aft - bef)  # shouldn't be all zeros!
```
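In case it helps narrow things down, here’s a sketch of a direct gradient check I could append after the training loop, using `Trace_ELBO.differentiable_loss` so the optimizer and the param store are taken out of the picture. An all-zero gradient here would at least tell me the parameter is frozen because the ELBO gradient itself vanishes, rather than because the update is being lost somewhere:

```python
# Sketch: inspect the ELBO gradient of `mu` directly, bypassing the optimizer.
var_params.update_parameters()   # rebuild distributions from current params
var_params.mu.grad = None        # clear any stale gradient
loss = pyro.infer.Trace_ELBO().differentiable_loss(gplvm.model, guide)
loss.backward()
print(var_params.mu.grad)        # all zeros would explain the frozen parameter
```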