Custom guide vs. AutoGuide in a linear latent variable model

Hi,
I am building a simple LV model: given data D, I want to find a factorization D = S * z, where z is my latent variable of dimension K and follows a multivariate normal distribution. Two questions came up:

  1. I see very different learning behavior between (A) providing my own guide and (B) using the AutoNormal autoguide. The former is very unstable, whereas the latter gives a smooth loss curve. Can somebody explain where the difference comes from? (By the way, I have already tried a range of learning rates.)
guide = pyro.infer.autoguide.AutoNormal(model)

vs.

def guide(data):
    loc_z_q0   = pyro.param("loc_z_q0", torch.zeros(K))
    scale_z_q0 = pyro.param("scale_z_q0", torch.eye(K, K), constraint=constraints.positive)
    z0         = pyro.sample("z0", dist.MultivariateNormal(loc_z_q0, scale_z_q0))

    return {"z0": z0}

def model(data):
    """Model for data = s @ z + sigma."""
    s = pyro.param("s", torch.randn(G, K))

    z0_loc   = pyro.param("z0_loc", torch.zeros(K))
    z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive)
    z0       = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
    sigma0 = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive)

    mean = s @ z0
    mean = mean.squeeze()  # Code blows up otherwise

    for i in range(data.shape[1]):  # iterate over samples (columns of data), not features
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mean, sigma0), obs=data[:, i])
  2. Even after long training and with a converged loss, the params pyro.param("AutoNormal.locs.z0") and pyro.param("AutoNormal.scales.z0") are both wrong, i.e. they don't represent an MVN(0,1) distribution.
    Moreover, pyro.param("AutoNormal.scales.z0") has an odd shape: only K, whereas pyro.param("z0_scale") has K x K.
    So is something systematically wrong in my code?

Minimum Working Example:


import numpy as np
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
import torch

pyro.set_rng_seed(4321)
pyro.clear_param_store()

# Data
G = 20    # Features
N = 1000  # Samples
K = 5     # Latent Dimensions
s_true = torch.from_numpy(np.random.randint(0, 3, size=(G, K))).float()
z_true = torch.from_numpy(np.random.multivariate_normal(np.zeros(K), np.eye(K), size=N).T).float()
data = s_true @ z_true


def model(data):
    s = pyro.param("s", torch.randn(G, K))

    z0_loc   = pyro.param("z0_loc", torch.zeros(K))
    z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive)
    z0       = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
    sigma0 = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive)

    mean = s @ z0
    mean = mean.squeeze()

    for i in range(data.shape[1]):  # iterate over samples (columns of data), not features
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mean, sigma0), obs=data[:, i])


def guide(data):
    loc_z_q0   = pyro.param("loc_z_q0", torch.zeros(K))
    scale_z_q0 = pyro.param("scale_z_q0", torch.eye(K, K), constraint=constraints.positive)
    z0         = pyro.sample("z0", dist.MultivariateNormal(loc_z_q0, scale_z_q0))

    return {"z0": z0}

# OR
# guide = pyro.infer.autoguide.AutoNormal(model)

# setup the optimizer
adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

steps = 12000
losses = np.zeros(steps)
for i in range(steps):
    losses[i] = svi.step(data)
    if i % 100 == 0:
        print(f"[iteration {i}] loss: {losses[i]:.4f}")

Thanks in advance!

constraints.positive is not a valid constraint for positive definite matrices; it simply enforces positive entries. You want constraints.positive_definite.


Thanks a bunch! That makes the loss function much more stable.
However, the code now raises

ValueError: Expected parameter covariance_matrix (Tensor of shape (50, 50)) of distribution MultivariateNormal(loc: torch.Size([50]), covariance_matrix: torch.Size([50, 50])) to satisfy the constraint PositiveDefinite(), ...

quite often. It seems to be related to the learning rate. Is this expected to happen this frequently, and what are the common causes?

This is expected to happen if, e.g., your learning rate is too high. Using 64-bit precision can also mitigate it. See this tutorial for other generic tips.
