Hi,
I am building a simple LV model: Given data D
, I want to find a factorization D = S * z
where z is my latent variable of dimension K
and fallows a Multivariate Normal distribution. 2 questions came up:
- I face a very different learning behavior between (A) providing a guide and (B) using the
AutoNormal
auto-guide. The former is very unstable whereas the latter shows a smooth loss curve. Can somebody elaborate where the difference is coming from? (Btw, I already iterated over the learning rate)
guide = pyro.infer.autoguide.AutoNormal(model)
vs.
def guide(data):
loc_z_q0 = pyro.param("loc_z_q0", torch.zeros(K))
scale_z_q0 = pyro.param("scale_z_q0", torch.eye(K, K), constraint=constraints.positive)
z0 = pyro.sample("z0", dist.MultivariateNormal(loc_z_q0, scale_z_q0))
return {"z0": z0}
def model(data):
"""Model for data = s @ z + sigma."""
s = pyro.param("s", torch.randn(G, K))
z0_loc = pyro.param("z0_loc", torch.zeros(K))
z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive)
z0 = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
sigma0 = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive)
mean = s @ z0
mean = mean.squeeze() # Code blows up otherwise
for i in range(len(data)):
pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mean, sigma0), obs=data[:, i])
- Even after a long training and a converged loss function, the params in
pyro.param("AutoNormal.locs.z0")
andpyro.param("AutoNormal.scales.z0")
are both incorrect, i.e. they don’t represent a MVN(0,1) distribution.
Moreover,pyro.param("AutoNormal.scales.z0")
has a weird shape: onlyK
, but pyro.param(“z0_scale”) hasK x K
.
Hence, I guess something is systematically wrong in my code?
Minimum Working Example:
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch
pyro.set_rng_seed(4321)
pyro.clear_param_store()
# Data
G = 20 # Features
N = 1000 # Samples
K = 5 # Latent Dimensions
s_true = torch.from_numpy(np.random.randint(0, 3, size=(G, K))).float()
z_true = torch.from_numpy(np.random.multivariate_normal(np.zeros(K), np.eye(K), size=N).T).float()
data = s_true @ z_true
def model(data):
s = pyro.param("s", torch.randn(G, K))
z0_loc = pyro.param("z0_loc", torch.zeros(K))
z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive)
z0 = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
sigma0 = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive)
mean = s @ z0
mean = mean.squeeze()
for i in range(len(data)):
pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mean, sigma0), obs=data[:, i])
def guide(data):
loc_z_q0 = pyro.param("loc_z_q0", torch.zeros(K))
scale_z_q0 = pyro.param("scale_z_q0", torch.eye(K, K), constraint=constraints.positive)
z0 = pyro.sample("z0", dist.MultivariateNormal(loc_z_q0, scale_z_q0))
return {"z0": z0}
# OR
# guide = pyro.infer.autoguide.AutoNormal(model)
# setup the optimizer
adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
steps = 12000
losses = np.zeros(steps)
for i in range(steps):
losses[i] = svi.step(data)
if i % 100 == 0:
print(f"[iteration {i}] loss: {losses[i]:.4f}")
Thanks in advance!