# Custom vs Autoguide in linear Latent Variable Model

Hi,
I am building a simple LV model: given data `D`, I want to find a factorization `D = S * z`, where `z` is my latent variable of dimension `K` and follows a multivariate normal distribution. Two questions came up:

1. I see very different learning behavior between (A) providing a custom guide and (B) using the `AutoNormal` autoguide. The former is very unstable, whereas the latter shows a smooth loss curve. Can somebody explain where the difference comes from? (Btw, I already iterated over the learning rate.)
```python
guide = pyro.infer.autoguide.AutoNormal(model)
```

vs.

```python
def guide(data):
    loc_z_q0   = pyro.param("loc_z_q0", torch.zeros(K))
    scale_z_q0 = pyro.param("scale_z_q0", torch.eye(K, K), constraint=constraints.positive)
    z0         = pyro.sample("z0", dist.MultivariateNormal(loc_z_q0, scale_z_q0))

    return {"z0": z0}

def model(data):
    """Model for data = s @ z + sigma."""
    s = pyro.param("s", torch.randn(G, K))

    z0_loc   = pyro.param("z0_loc", torch.zeros(K))
    z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive)
    z0       = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
    sigma0   = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive)

    mean = s @ z0
    mean = mean.squeeze()  # Code blows up otherwise

    for i in range(len(data)):
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mean, sigma0), obs=data[:, i])
```
2. Even after long training and a converged loss, the params in `pyro.param("AutoNormal.locs.z0")` and `pyro.param("AutoNormal.scales.z0")` are both incorrect, i.e. they don't represent an MVN(0, 1) distribution.
Moreover, `pyro.param("AutoNormal.scales.z0")` has a weird shape: only `K`, whereas `pyro.param("z0_scale")` is `K x K`.
Hence, I guess something is systematically wrong in my code?

Minimum Working Example:

```python
import numpy as np
import torch

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import SVI, Trace_ELBO

pyro.set_rng_seed(4321)
pyro.clear_param_store()

# Data
G = 20    # Features
N = 1000  # Samples
K = 5     # Latent Dimensions
s_true = torch.from_numpy(np.random.randint(0, 3, size=(G, K))).float()
z_true = torch.from_numpy(np.random.multivariate_normal(np.zeros(K), np.eye(K), size=N).T).float()
data = s_true @ z_true

def model(data):
    s = pyro.param("s", torch.randn(G, K))

    z0_loc   = pyro.param("z0_loc", torch.zeros(K))
    z0_scale = pyro.param("z0_scale", torch.eye(K, K), constraint=constraints.positive)
    z0       = pyro.sample("z0", dist.MultivariateNormal(z0_loc, z0_scale))
    sigma0   = pyro.param("sigma0", torch.eye(G, G), constraint=constraints.positive)

    mean = s @ z0
    mean = mean.squeeze()

    for i in range(len(data)):
        pyro.sample("obs_{}".format(i), dist.MultivariateNormal(mean, sigma0), obs=data[:, i])

def guide(data):
    loc_z_q0   = pyro.param("loc_z_q0", torch.zeros(K))
    scale_z_q0 = pyro.param("scale_z_q0", torch.eye(K, K), constraint=constraints.positive)
    z0         = pyro.sample("z0", dist.MultivariateNormal(loc_z_q0, scale_z_q0))

    return {"z0": z0}

# OR
# guide = pyro.infer.autoguide.AutoNormal(model)

# setup the optimizer (exact optimizer/learning rate assumed; the original post omitted it)
optimizer = pyro.optim.Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

steps = 12000
losses = np.zeros(steps)
for i in range(steps):
    losses[i] = svi.step(data)
    if i % 100 == 0:
        print(f"[iteration {i}] loss: {losses[i]:.4f}")
```

`constraints.positive` is not a valid constraint for covariance matrices: it only enforces positive entries, not positive definiteness. You want `constraints.positive_definite`.


Thanks a bunch! That makes the loss function much more stable.
However, I observe that the code now tends to raise

`ValueError: Expected parameter covariance_matrix (Tensor of shape (50, 50)) of distribution MultivariateNormal(loc: torch.Size([50]), covariance_matrix: torch.Size([50, 50])) to satisfy the constraint PositiveDefinite(), ...`

quite often. It seems to be related to the learning rate. Is it expected to happen this frequently, or what are the common causes?

This is expected to happen if, e.g., your learning rate is too high. Using 64-bit precision can also mitigate it. See this tutorial for other generic tips.
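For the 64-bit suggestion, one way to switch (a sketch; note that tensors created *before* the switch stay `float32` and need an explicit cast) is:

```python
import numpy as np
import torch

# New tensors now default to float64.
torch.set_default_dtype(torch.float64)

x = torch.zeros(3)  # created after the switch -> torch.float64

# Tensors that already exist (e.g. data built from NumPy float32 arrays)
# must be cast explicitly.
data32 = torch.from_numpy(np.ones((2, 2), dtype=np.float32))
data64 = data32.double()
```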
