Inferring Discrete Latent Variable

I have tried to create a simple model in which a discrete latent variable, expert, should be inferred.
However, the ELBO does not change when I do inference. I would expect it to decrease, since the parameter of expert should increase. What am I missing?

import pyro
from pyro import poutine
from pyro.infer import TraceEnum_ELBO
from pyro.infer.autoguide import AutoNormal, AutoDelta
from pyro.infer import MCMC, NUTS
import torch

def model(data=None):
    expert = pyro.sample("expert", pyro.distributions.Bernoulli(.5), infer={"enumerate": "parallel"}) 

    with pyro.plate("data"):
        correct = pyro.sample("correct", pyro.distributions.Bernoulli(expert), obs=data)

    return expert, correct

# Data
data = torch.ones(30)

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import ClippedAdam

# Define the number of optimization steps
n_steps = 1000

# Setup the optimizer
adam_params = {"lr": 0.1}

optimizer = ClippedAdam(adam_params)
guide = AutoNormal(poutine.block(model))

# Setup the inference algorithm
elbo = TraceEnum_ELBO(num_particles=1, max_plate_nesting=1)
svi = SVI(model, guide, optimizer, loss=elbo)

# Reset parameter values
pyro.clear_param_store()

# Do gradient steps
for step in range(n_steps):
    elbo = svi.step(data)

    if step % 100 == 0:
        print("[%d] ELBO: %.1f" % (step, elbo))

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=100)

samples = predictive(data)

expertGuess = samples["expert"].mean(axis=0)
print(expertGuess)

you haven’t declared any learnable parameters in your model (e.g. the Bernoulli probability .5 is fixed). consequently the ELBO can be computed exactly, by summing over both values of expert, and there’s nothing to learn.
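As a rough check of this point, here is a minimal plain-Python sketch (no Pyro; the helper exact_loss is hypothetical) of the quantity TraceEnum_ELBO can compute exactly for the first model: the negative log marginal likelihood with expert summed over {0, 1}. It assumes the data from the question (30 ones) and the fixed prior .5, and it comes out to log 2 ≈ 0.693, a number no optimization step can change.

```python
import math

def exact_loss(data, prior=0.5):
    """Negative log marginal likelihood of the first model, with expert summed out.

    In that model p(correct=1 | expert) is expert itself, i.e. 0.0 or 1.0.
    """
    marginal = 0.0
    for expert, weight in ((0.0, 1.0 - prior), (1.0, prior)):
        likelihood = 1.0
        for x in data:
            likelihood *= expert if x == 1 else 1.0 - expert
        marginal += weight * likelihood
    return -math.log(marginal)

data = [1.0] * 30
# Nothing in exact_loss depends on a learnable quantity, so repeated
# "optimization steps" would all report the same value.
print(round(exact_loss(data), 4))  # 0.6931
```

Passing a different prior, e.g. exact_loss(data, prior=0.9), gives a smaller loss (-log 0.9 ≈ 0.105); that kind of improvement only becomes reachable once the probability is declared with pyro.param.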


Thank you for the fast reply.
Forgetting to declare it as a parameter was the issue.

However, using MCMC (NUTS) to do the inference instead of SVI results in no samples being generated. How can this be?

Furthermore, using infer={"enumerate": "sequential"} instead of infer={"enumerate": "parallel"} gives the following UserWarning:

Found vars in model but not guide: {'expert'}

From the docs, it seems like “parallel” is only used for performance reasons. However, this does not seem to be the case.

from torch.distributions import constraints

def model(data=None):
    expert_prob = pyro.param("expert_prob", torch.tensor(.5), constraint=constraints.unit_interval)
    expert = pyro.sample("expert", pyro.distributions.Bernoulli(expert_prob), infer={"enumerate": "parallel"})    

    skills = torch.tensor([.2, .9])
    skill_level = skills[expert.long()]

    with pyro.plate("data"):
        correct = pyro.sample("correct", pyro.distributions.Bernoulli(skill_level), obs=data)

    return expert_prob, expert, correct

# Data
data = torch.ones(30)

pyro.clear_param_store()

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=3000, warmup_steps=500, num_chains=1)
mcmc.run(data)

# Show summary of inference results
mcmc.summary()

posterior_samples = mcmc.get_samples()
print(posterior_samples)

expert_prob_guess = pyro.param("expert_prob").item()
print(expert_prob_guess)

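NUTS only operates on continuous latent variables. any discrete latent variables are integrated out (enumerated and summed over). since your model has no continuous latent variables left after that, there is nothing for NUTS to sample, so not getting any samples from NUTS is expected.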
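To illustrate what “integrated out” means here, a minimal plain-Python sketch (no Pyro; helper names are hypothetical) sums expert out of the second model, with expert_prob held at its initial value .5 and skills = (0.2, 0.9) as in the question. Once that sum is taken, no latent site is left for NUTS to sample.

```python
import math

def log_bernoulli(x, p):
    """Log-probability of one 0/1 observation x under Bernoulli(p)."""
    return math.log(p) if x == 1 else math.log(1.0 - p)

def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def log_marginal(data, expert_prob=0.5, skills=(0.2, 0.9)):
    """log p(data) for the second model, with the discrete site expert summed out."""
    terms = []
    for expert, prior in ((0, 1.0 - expert_prob), (1, expert_prob)):
        log_lik = sum(log_bernoulli(x, skills[expert]) for x in data)
        terms.append(math.log(prior) + log_lik)
    return logsumexp(terms)

data = [1] * 30
# expert no longer appears in the result: it has been summed over, leaving
# no continuous latent variable for a sampler like NUTS to explore.
print(round(log_marginal(data), 3))  # -3.854
```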

i’m not sure about your user warning. can you please post a complete, self-contained script that reproduces this error (i.e. with import statements etc)?

Sure, here it is.

import pyro
from pyro import poutine
from pyro.infer import TraceEnum_ELBO
from pyro.infer.autoguide import AutoNormal, AutoDelta
from pyro.infer import MCMC, NUTS

import torch
from torch.distributions import constraints

def model(data=None):
    expert_prob = pyro.param("expert_prob", torch.tensor(.5), constraint=constraints.unit_interval)
    expert = pyro.sample("expert", pyro.distributions.Bernoulli(expert_prob), infer={"enumerate": "sequential"})

    skills = torch.tensor([.2, .9])
    skill_level = skills[expert.long()]

    with pyro.plate("data"):
        correct = pyro.sample("correct", pyro.distributions.Bernoulli(skill_level), obs=data)

    return expert_prob, expert, correct

# Data
data = torch.ones(30)

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import ClippedAdam

# Define the number of optimization steps
n_steps = 1000

# Setup the optimizer
adam_params = {"lr": 0.1}
optimizer = ClippedAdam(adam_params)
guide = AutoNormal(poutine.block(model))

# Setup the inference algorithm
elbo = TraceEnum_ELBO(num_particles=1, max_plate_nesting=1)
svi = SVI(model, guide, optimizer, loss=elbo)

# Reset parameter values
pyro.clear_param_store()

# Do gradient steps
for step in range(n_steps):
    elbo = svi.step(data)

    if step % 100 == 0:
        print("[%d] ELBO: %.1f" % (step, elbo))

from pyro.infer import Predictive

predictive = Predictive(model, guide=guide, num_samples=100)
samples = predictive(data)

expert_guess = samples["expert"].mean(axis=0)
print(expert_guess)

expert_prob_guess = pyro.param("expert_prob").item()
print(expert_prob_guess)

With regard to not including a parameter in the model, why is it not an issue with the following model, where beta is inferred? A mean and std. are inferred for the distribution of beta, but aren’t they “fixed”? However, the parameter of the Bernoulli distribution is not inferred “automatically” in the previous model.

def my_model(x, y_obs):
    beta = pyro.sample("beta", pyro.distributions.Normal(0., 1.))

    with pyro.plate("data", len(y_obs)):
        y = pyro.sample("y", pyro.distributions.Normal(beta*x, 1.), obs=y_obs)

    return y

sorry, i haven’t had a chance to take a closer look at this and won’t have access to a computer for the next while, so i’m afraid my response will take some time


@Sollertis here’s a complete working example. sorry for the delay.

for one thing you can’t use Predictive the way you were using it.

import pyro
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import ClippedAdam
import torch
from torch.distributions import constraints


def model(data=None):
    expert_prob = pyro.param("expert_prob", torch.tensor(.5), constraint=constraints.unit_interval)
    expert = pyro.sample("expert", pyro.distributions.Bernoulli(expert_prob), infer={"enumerate": "parallel"})

    skills = torch.tensor([.2, .9])
    skill_level = skills[expert.long()]

    with pyro.plate("data"):
        correct = pyro.sample("correct", pyro.distributions.Bernoulli(skill_level), obs=data)


def guide(data=None):
    pass


# Data
data = torch.ones(30)

# Define the number of optimization steps
n_steps = 1001

# Setup the optimizer
optimizer = ClippedAdam({"lr": 0.03})

# Setup the inference algorithm
elbo = TraceEnum_ELBO(num_particles=1, max_plate_nesting=1)
svi = SVI(model, guide, optimizer, loss=elbo)

# Reset parameter values
pyro.clear_param_store()

# Do gradient steps
for step in range(n_steps):
    elbo = svi.step(data)

    if step % 100 == 0:
        print("[%d] ELBO: %.2f" % (step, elbo))

expert_prob_guess = pyro.param("expert_prob").item()
print("expert_prob", expert_prob_guess)
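One way to sanity-check what this example should print: with expert summed out, the loss as a function of q = expert_prob is -log[(1 - q)·0.2^30 + q·0.9^30]. The plain-Python sketch below (no Pyro; loss is a hypothetical helper that assumes 30 observations all equal to 1) evaluates it on a grid. Because 0.9^30 > 0.2^30, the loss falls steadily as q grows, so the learned expert_prob should drift toward 1.

```python
import math

SKILLS = (0.2, 0.9)  # skill level for expert = 0 and expert = 1, as in the model

def loss(q, n_correct=30):
    """-log p(data | q) for n_correct observations that are all equal to 1."""
    marginal = (1.0 - q) * SKILLS[0] ** n_correct + q * SKILLS[1] ** n_correct
    return -math.log(marginal)

grid = [0.1, 0.3, 0.5, 0.7, 0.9, 0.99]
losses = [loss(q) for q in grid]
for q, value in zip(grid, losses):
    print(f"expert_prob={q:.2f}  loss={value:.3f}")
```

The lowest reachable loss is at q = 1, where it equals -30·log 0.9 ≈ 3.161.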

Thank you for your patience when answering these beginner questions.
I hope you can clear the following things up:

1. With regard to not including a parameter in the model, why is it not an issue with the following model, where beta is inferred?

def my_model(x, y_obs):
    beta = pyro.sample("beta", pyro.distributions.Normal(0., 1.))  # <- mean and std. are inferred. No need to use pyro.param

    with pyro.plate("data", len(y_obs)):
        y = pyro.sample("y", pyro.distributions.Normal(beta*x, 1.), obs=y_obs)

    return y

A mean and std. are inferred for the distribution of beta, but aren’t they “fixed” as well? Yet the parameter of the Bernoulli distribution is not inferred “automatically” in the previous model; it has to be explicitly declared as a parameter with pyro.param.

2. In your complete working solution, you state, “for one thing you can’t use Predictive the way you were using it.” Why can’t I use Predictive this way? How should Predictive be used?

3. Using your complete working solution, changing infer={"enumerate": "parallel"} to infer={"enumerate": "sequential"} yields “random” values of the parameter “expert_prob” after inference. Why is this the case? I would expect them to behave the same (the only difference being inference speed).

4. Why is the guide empty? In that case, how is the posterior distribution of “expert” to be inferred?

Thank you for the help.

[1] the point is that previously, with a fixed Bernoulli probability, the only thing to do was inference (not parameter learning), and inference is trivial (i.e. it can be done in closed form) for a simple model with only discrete latent variables.
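A small plain-Python sketch may make the contrast in [1] concrete. The helpers and the beta data are hypothetical, and the guide for my_model is assumed to be something like AutoNormal (the thread doesn't show which guide was used). For the discrete model, the posterior over expert is just two numbers obtained by Bayes' rule, so there is nothing to optimize. For my_model, the prior Normal(0, 1) is indeed fixed, but the posterior over the continuous beta is a Normal whose mean and std depend on the data; those are what a guide's learnable loc/scale parameters approximate, which is why no pyro.param is needed in the model itself.

```python
import math

def expert_posterior(data, expert_prob=0.5, skills=(0.2, 0.9)):
    """Exact p(expert = 1 | data) for the discrete model, by Bayes' rule."""
    log_w = []
    for expert, prior in ((0, 1.0 - expert_prob), (1, expert_prob)):
        p = skills[expert]
        log_lik = sum(math.log(p) if x == 1 else math.log(1.0 - p) for x in data)
        log_w.append(math.log(prior) + log_lik)
    m = max(log_w)  # subtract the max before exponentiating, for numerical stability
    w0, w1 = (math.exp(v - m) for v in log_w)
    return w1 / (w0 + w1)

def beta_posterior(x, y):
    """Exact posterior (mean, std) of beta in my_model.

    Prior beta ~ Normal(0, 1) and y_i ~ Normal(beta * x_i, 1) are conjugate,
    so the posterior is Normal with precision 1 + sum(x_i^2).
    """
    precision = 1.0 + sum(xi * xi for xi in x)
    mean = sum(xi * yi for xi, yi in zip(x, y)) / precision
    return mean, 1.0 / math.sqrt(precision)

print(expert_posterior([1] * 30))  # two-point posterior, no optimization involved
print(beta_posterior([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # hypothetical data
```

With the hypothetical data above, beta_posterior returns mean 28/15 ≈ 1.867 and std 1/√15 ≈ 0.258; an AutoNormal guide trained with SVI on my_model would be expected to move its loc and scale toward values like these.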

[3] sorry, i didn’t realize it at the time, but “model side” sequential enumeration is not implemented (only parallel enumeration is). we just added a better error message to catch this, since the docs are not super explicit about it.

[4] the guide is empty because pyro essentially computes the optimal guide automatically via enumeration
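The sketch below (plain Python, hypothetical helpers and data, expert_prob fixed at .5) illustrates why enumeration makes an explicit guide unnecessary: for any guide q(expert), ELBO(q) = Σ q(z)[log p(z, data) - log q(z)] ≤ log p(data), with equality exactly when q is the posterior. Summing over both values of expert yields that posterior, and hence the tightest bound, directly.

```python
import math

SKILLS = (0.2, 0.9)  # skill level for expert = 0 and expert = 1, as in the model

def log_joint(expert, data, expert_prob=0.5):
    """log p(expert, data) for the discrete model."""
    prior = expert_prob if expert == 1 else 1.0 - expert_prob
    p = SKILLS[expert]
    return math.log(prior) + sum(math.log(p) if x == 1 else math.log(1.0 - p) for x in data)

def elbo(q1, data):
    """Exact ELBO for a guide with q(expert = 1) = q1, summing over both values of expert."""
    total = 0.0
    for expert, q in ((0, 1.0 - q1), (1, q1)):
        if q > 0.0:  # terms with q(z) = 0 contribute nothing
            total += q * (log_joint(expert, data) - math.log(q))
    return total

data = [1, 1, 0, 1, 1]  # hypothetical observations
log_evidence = math.log(sum(math.exp(log_joint(e, data)) for e in (0, 1)))
posterior_1 = math.exp(log_joint(1, data) - log_evidence)  # exact p(expert = 1 | data)

print("log p(data)           :", log_evidence)
print("ELBO, exact posterior :", elbo(posterior_1, data))  # equals log p(data)
print("ELBO, q(expert=1) = .5:", elbo(0.5, data))          # strictly smaller
```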

let me respond to [2] separately


@Sollertis regarding Predictive, unfortunately Pyro doesn’t currently have great support for automating prediction when enumeration is used for discrete latent variables (although NumPyro does).

so the general recipe is to use pyro.poutine and other utilities to do this semi-automatically. in particular, the recipe is trace + replay followed by infer_discrete. trace + replay is basically what Predictive does under the hood.

in this particular case, where the only latent variables are discrete, this can be done as follows:

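from pyro import poutine
from pyro.infer import Predictive, infer_discrete

# sample expert from its exact posterior; read it off a trace since the model above has no return value
discrete_model = infer_discrete(model, first_available_dim=-2)
posterior_samples = torch.stack([poutine.trace(discrete_model).get_trace(data).nodes["expert"]["value"] for _ in range(100)])
posterior_samples = {'expert': posterior_samples}

predictive = Predictive(model, posterior_samples)
samples = predictive(data)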
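For intuition about the infer_discrete step, here is a plain-Python sketch (not Pyro's implementation; the helpers and data are hypothetical) of what it amounts to for this model's single discrete site: with the default temperature=1 it draws expert from its exact posterior, and with temperature=0 it returns the most probable value (MAP). Ambiguous data [1, 0, 1, 0] is used so the draws actually vary.

```python
import math
import random

def expert_posterior(data, expert_prob=0.5, skills=(0.2, 0.9)):
    """Exact p(expert = 1 | data), as in the model with expert_prob fixed."""
    log_w = []
    for expert, prior in ((0, 1.0 - expert_prob), (1, expert_prob)):
        p = skills[expert]
        log_w.append(math.log(prior) + sum(math.log(p) if x == 1 else math.log(1.0 - p) for x in data))
    m = max(log_w)
    w0, w1 = (math.exp(v - m) for v in log_w)
    return w1 / (w0 + w1)

def sample_expert(data, rng, temperature=1):
    """Hypothetical stand-in for infer_discrete on this single discrete site."""
    p1 = expert_posterior(data)
    if temperature == 0:
        return 1.0 if p1 >= 0.5 else 0.0  # MAP value
    return 1.0 if rng.random() < p1 else 0.0  # exact posterior sample

rng = random.Random(0)
data = [1, 0, 1, 0]  # hypothetical, deliberately ambiguous observations
samples = [sample_expert(data, rng) for _ in range(1000)]

print("p(expert=1 | data):", expert_posterior(data))
print("mean of 1000 draws:", sum(samples) / len(samples))
print("MAP value         :", sample_expert(data, rng, temperature=0))
```

In the Pyro recipe, the stacked posterior_samples play the role of samples here; Predictive then reruns the model conditioned on each drawn value of expert.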