SVI for a model with a discrete latent variable

Hi all, first of all, thanks for building Pyro, it is a lot of fun to work with!

I have a model and a guide with stochastic control flow, and I am trying to fit them with SVI. The full MWE is:

import torch
import pyro
import pyro.distributions as dist

from torch.distributions import constraints


def model():
    branch = pyro.sample(
        "branch",
        dist.Bernoulli(probs=torch.tensor([0.5])),
    )
    if branch == 1:
        z1 = pyro.sample("z1", dist.Normal(-3, 1))
    else:
        z1 = pyro.sample("z1", dist.Normal(3, 1))

    x = pyro.sample("x", dist.Normal(z1, 2), obs=torch.tensor(2.0))


def guide():
    weight = pyro.param(
        "weight", torch.tensor(0.5), constraint=constraints.interval(0, 1)
    )
    m1 = pyro.param("m1", torch.tensor(0.0))
    s1 = pyro.param("s1", torch.tensor(1.0), constraint=constraints.positive)
    m2 = pyro.param("m2", torch.tensor(0.0))
    s2 = pyro.param("s2", torch.tensor(1.0), constraint=constraints.positive)

    branch = pyro.sample(
        "branch",
        dist.Bernoulli(probs=torch.tensor([weight])),
    )
    if branch == 1:
        z1 = pyro.sample("z1", dist.Normal(m1, s1))
    else:
        z1 = pyro.sample("z1", dist.Normal(m2, s2))


def train_svi(model, guide, optim, num_iterations):
    svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO())

    pyro.clear_param_store()

    for _ in range(num_iterations):
        loss = svi.step()


def main():
    pyro.set_rng_seed(0)
    train_svi(model, guide, pyro.optim.Adam({"lr": 0.05}), 1000)
    param_store = pyro.get_param_store()
    for name, val in param_store.items():
        print(f"{name}: {val}")


if __name__ == "__main__":
    main()

I am specifically interested in the behaviour of variational inference for this model, so while I assume I could marginalize the discrete variable out with enumeration, I would like to investigate how plain VI behaves (a sketch of what I assume the enumeration variant would look like is below).
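For reference, this is roughly how I assume the discrete site could be enumerated out instead, using sequential enumeration so that the Python if still works (these settings are my guess and I have not verified them):

import pyro
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

# sequential enumeration re-runs the guide once per value of "branch",
# so the value-dependent control flow is preserved
enum_guide = config_enumerate(guide, "sequential")
svi = SVI(model, enum_guide, pyro.optim.Adam({"lr": 0.05}), loss=TraceEnum_ELBO())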

The MWE outputs:

weight: 0.5
m1: -2.095879554748535
s1: 0.9261342883110046
m2: 2.8922441005706787
s2: 0.8191512823104858

I have two questions:

  1. The parameters of the normal distribution on each branch get optimized correctly; however, the gradient for weight appears to always be 0, so it never changes. Is this the intended behaviour when using Trace_ELBO? By my calculations, if I compute the gradient by hand using the score function estimator, it should not be 0 (see the sketch after this list). So I am wondering whether I am doing something wrong, e.g. not using the correct ELBO.
  2. Which gradient estimator does Trace_ELBO use? The documentation names Automated Variational Inference in Probabilistic Programming as a reference, which would suggest it uses the score function estimator for each latent variable. However, reading through the implementation here and here, it looks like it uses information about whether a given latent variable can be reparameterized and, based on that, builds a surrogate objective that can be used with autodiff. So is it an adapted version of the algorithm by Wingate and Weber?
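Here is the rough by-hand check I mentioned in question 1: a single-sample score-function (REINFORCE) estimate of the ELBO gradient with respect to weight, with the other guide parameters frozen at their initial values (this is my own sketch of the estimator, so the details may be off):

import torch

weight = torch.tensor(0.5, requires_grad=True)
q_branch = torch.distributions.Bernoulli(probs=weight)
b = q_branch.sample()

# remaining guide parameters frozen at their initial values for this check
q_z = torch.distributions.Normal(0.0, 1.0)
z = q_z.sample()

# model terms from the MWE above
p_branch = torch.distributions.Bernoulli(probs=torch.tensor(0.5))
p_z = torch.distributions.Normal(-3.0 if b == 1 else 3.0, 1.0)
p_x = torch.distributions.Normal(z, 2.0)

elbo = (
    p_branch.log_prob(b) + p_z.log_prob(z) + p_x.log_prob(torch.tensor(2.0))
    - q_branch.log_prob(b) - q_z.log_prob(z)
)

# score-function term: grad of log q(b), weighted by the detached ELBO estimate
(q_branch.log_prob(b) * elbo.detach()).backward()
print(weight.grad)  # generically nonzero (unless the sampled ELBO is exactly 0)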

you want to change

probs=torch.tensor([weight])

in your guide to

probs=weight

as wrapping it in torch.tensor(...) messes up gradients. weight is already a tensor.
i’m actually surprised this doesn’t error.
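a minimal sketch of what goes wrong: torch.tensor copies the data and detaches it from the autograd graph, so the guide’s Bernoulli never sees the weight parameter.

import torch

w = torch.tensor(0.5, requires_grad=True)
wrapped = torch.tensor([w])  # copies the value; the result has no autograd history

print(w.requires_grad)        # True
print(wrapped.requires_grad)  # False: gradients can no longer flow back to w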

as to your second question: yes, it is an adapted version of wingate and weber that uses reparameterized gradients where possible and falls back on score function gradients where it’s not.
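concretely, torch.distributions exposes a has_rsample flag that can be used to make that choice, and a single-sample version of the resulting surrogate looks roughly like this (a simplified sketch of the idea, not Pyro’s actual implementation; q_branch, q_z and log_joint are stand-ins):

import torch
import torch.distributions as td

print(td.Normal(0.0, 1.0).has_rsample)  # True  -> pathwise (reparameterized) gradient
print(td.Bernoulli(0.5).has_rsample)    # False -> fall back on the score function

def surrogate_loss(q_branch, q_z, log_joint):
    b = q_branch.sample()   # no gradient path through a Bernoulli sample
    z = q_z.rsample()       # reparameterized: gradients flow through z

    log_q = q_branch.log_prob(b) + q_z.log_prob(z)
    elbo = log_joint(b, z) - log_q  # single-sample ELBO estimate

    # add a score-function term for the non-reparameterizable site:
    # differentiating log q(b) * elbo.detach() yields the REINFORCE gradient
    surrogate = elbo + q_branch.log_prob(b) * elbo.detach()
    return -surrogate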

Thanks a lot @martinjankowiak, it now works without a problem!