 Another getting started question - MLE when there are latent variables

I’m going through the example at MLE and MAP Estimation — Pyro Tutorials 1.7.0 documentation.

On that page it says the guide can be empty because there are no latent variables, and indeed that seems to work, although the reason why it works is a bit mysterious to me.
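A short way to see why an empty guide works. This is a standard identity in generic notation, not something taken from the tutorial. The loss that `Trace_ELBO` minimizes is the negative of the evidence lower bound (ELBO), which relates to the log-likelihood as follows:

```latex
\log p_\theta(x)
  = \underbrace{\mathbb{E}_{q_\phi(z)}\!\left[\log p_\theta(x, z) - \log q_\phi(z)\right]}_{\mathrm{ELBO}(\theta,\,\phi)}
  + \mathrm{KL}\!\left(q_\phi(z) \,\big\|\, p_\theta(z \mid x)\right)
```

With no latent variables $z$, the expectation and the KL term both disappear and $\mathrm{ELBO}(\theta) = \log p_\theta(x)$. Minimizing the loss over the `pyro.param` values is then exactly maximum likelihood. With latent variables, the ELBO equals the log-likelihood only when the guide $q_\phi(z)$ matches the exact posterior $p_\theta(z \mid x)$. Otherwise, maximizing it over $\theta$ can land somewhere other than the MLE.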

What I’m stuck on is changing that example to a case where there are some latent variables and the guide is not empty. So I decided to model a case where first a fair coin is flipped and then, based on the result, either another fair coin is flipped or a biased one, with unknown bias.

The code for my attempt is below, but unfortunately it doesn’t behave as I would expect it to. With the data I’ve given it, I’m fairly sure the maximum likelihood estimate for the unknown bias should be 1.0, but it gives an estimate of around 0.68 instead.

So I must have done something wrong somewhere. But what? How can I understand what the guide should contain in this simple case?

I am sorry for asking such basic questions.

```python
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# data is 8 heads and 2 tails
data = torch.zeros(10)
data[0:8] = 1.0

def train(model, guide, lr=0.01):
    pyro.clear_param_store()
    optimiser = Adam({"lr": lr})  # was undefined in the original snippet
    svi = SVI(model, guide, optimiser, loss=Trace_ELBO())

    n_steps = 101
    for step in range(n_steps):
        loss = svi.step(data)
        if step % 50 == 0:
            print(f"[iter {step}] loss: {loss:.4f}")

def model_mle(data):
    f = pyro.param(
        "latent_fairness",
        torch.tensor(0.5),
        constraint=constraints.unit_interval
    )
    for i in pyro.plate("data", data.size(0)):
        # hidden flip: 0 -> fair coin, 1 -> coin with unknown bias f
        coin_choice = pyro.sample(f"coin_choice_{i}", dist.Bernoulli(0.5))
        if coin_choice.item() == 0.0:
            bias = 0.5
        else:
            bias = f
        pyro.sample(f"obs_{i}", dist.Bernoulli(bias), obs=data[i])

def guide_mle(data):
    for i in pyro.plate("data", data.size(0)):
        coin_choice = pyro.sample(f"coin_choice_{i}", dist.Bernoulli(0.5))

train(model_mle, guide_mle)
print(f"latent fairness estimate: {pyro.param('latent_fairness').item():.4f}")
```
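A sketch of what this particular guide asks the optimizer to do. This is my own analysis, not from the thread. `guide_mle` samples each `coin_choice_i` from the same Bernoulli(0.5) as the model's prior and has no parameters of its own, so the log p(z) and log q(z) terms cancel. The ELBO then reduces to the expected log-likelihood under the prior, E_q[log p(x | z, f)]. That objective only "sees" f on the draws where the biased coin is chosen, so it is maximized at the plain fraction of heads, 0.8, rather than at 1.0. The helper names below are hypothetical, and the code is plain Python with no Pyro.

```python
import math

# data as in the question: 8 heads (1) and 2 tails (0)
data = [1] * 8 + [0] * 2

def bernoulli_log_prob(x, p):
    """log P(x) for a Bernoulli(p) variable, x in {0, 1}."""
    return math.log(p) if x == 1 else math.log(1.0 - p)

def elbo_with_prior_guide(f):
    """ELBO when q(coin_choice_i) = Bernoulli(0.5), i.e. the guide equals the prior.

    log p(z) and log q(z) cancel, leaving E_q[log p(x | z, f)], computed
    exactly by averaging over both values of each coin_choice_i
    (0 -> fair coin, 1 -> biased coin, as in model_mle).
    """
    total = 0.0
    for x in data:
        total += 0.5 * bernoulli_log_prob(x, 0.5)  # coin_choice = 0
        total += 0.5 * bernoulli_log_prob(x, f)    # coin_choice = 1
    return total

# grid search over the open interval (0, 1)
grid = [k / 1000 for k in range(1, 1000)]
best_f = max(grid, key=elbo_with_prior_guide)
print(f"f maximizing the prior-guide ELBO: {best_f:.3f}")
```

With a learning rate of 0.01 and only 101 steps starting from 0.5, the SVI estimate may simply not have reached this point yet. Either way, the optimum of this objective is not the maximum likelihood estimate.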

i’m afraid i don’t really understand your model…

does this updated tutorial help? see the section at the end:

Thank you for the reply. The AutoGuide machinery sounds useful, but I guess before learning to use it I would still like to understand how to hand-design a guide for MLE in the case where I have a latent variable, so that I can understand what the guide should contain in that case.

My model was a somewhat arbitrary one, constructed only for the sake of having both a latent variable and a parameter. Let me explain it in more detail, though:

Suppose we have three coins. Coins ‘A’ and ‘B’ are known to be fair, but coin ‘C’ has an unknown bias f. We don’t have a prior over f. (That is, f is a parameter rather than a latent variable.) We are given data from 10 trials of the following experiment:

• first flip coin ‘A’ (known to be fair)
• if ‘A’ comes up heads, flip coin ‘B’ (known to be fair) and return the result
• otherwise, flip coin ‘C’ (with unknown bias f) and return the result

So the result from each trial is either heads or tails, but we don’t know if it came from coin ‘B’ or coin ‘C’, which is to say, we don’t know the value of the latent variable ‘A’. The goal is to get a maximum likelihood estimate of the unknown parameter f, given this data.
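The experiment described above can be simulated directly. The sketch below is plain Python using only the standard library, and `simulate_trial` is a hypothetical helper name. Each trial marginally comes up heads with probability 1/2 · 1/2 + 1/2 · f, so with f = 1.0 the long-run fraction of heads should approach 3/4.

```python
import random

def simulate_trial(f, rng):
    """One trial: flip fair coin A; heads -> flip fair coin B, tails -> flip coin C (bias f).

    Only the final result is returned; the outcome of A stays hidden (latent).
    """
    a_is_heads = rng.random() < 0.5
    p_heads = 0.5 if a_is_heads else f
    return 1 if rng.random() < p_heads else 0

rng = random.Random(0)
n = 100_000
results = [simulate_trial(1.0, rng) for _ in range(n)]
print(f"fraction of heads with f = 1.0: {sum(results) / n:.3f}")
```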

So then my question is just, what should I put in the guide, in that case? I made a guess at that in my original post, but it doesn’t behave as expected and I conclude I must be doing something wrong, so I’m looking for feedback on what.

i see. your model as written does not reflect what you describe. each coin is flipped only once. there should only be a single random variable corresponding to each flip. however you have two for each flip: one is latent and one is observed. instead you want something that (very) roughly looks like this (you can use an empty guide for this since there are no latent variables):

```python
def model_mle(data):
    f = pyro.param(
        "latent_fairness",
        torch.tensor(0.5),
        constraint=constraints.unit_interval
    )
    prev_coin = None
    for i in pyro.plate("data", data.size(0)):
        if i > 0 and prev_coin == 0.0:  # or whatever the right condition is
            bias = 0.5
        else:
            bias = f
        pyro.sample(f"obs_{i}", dist.Bernoulli(bias), obs=data[i])
        prev_coin = data[i].item()
```

note that all random variables are observed.

Now I’m even more confused. I think maybe I still didn’t explain the model properly. In the intended model not all of the coin flips are observed - the point of the exercise is to come up with a case where the guide will not be empty.

In the intended model, each data point involves two coin flips: first a flip of coin A (fair), and then either a flip of coin B (fair) or coin C (with bias f). But the first flip, of coin A, is not observed, and so we don’t know whether the data point is from coin B or coin C. (The data points are independent of one another. I only used a for loop so that I could write imperative code - I know that writing it in a vectorised way would be more efficient, but I don’t really care about that for this toy model.)

Maybe a different model would make more sense. This is really just the first thing I came up with. I’m just trying to come up with a simple minimal model in which it makes sense to do maximum likelihood inference but where the guide is not empty, so that I can figure out what the code ought to look like in that case.

I’m giving it data consisting of 8 heads and 2 tails, so the likelihood should be (1/4+f/2)^8*(3/4-f/2)^2, which is maximised by f=1.0, so that’s why I’m expecting that as the result.
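The claim that f = 1.0 is the maximizer can be checked by differentiating the log-likelihood. This worked step is added for clarity:

```latex
\begin{aligned}
\log L(f) &= 8 \log\!\left(\tfrac14 + \tfrac f2\right) + 2 \log\!\left(\tfrac34 - \tfrac f2\right) \\
\frac{d}{df} \log L(f) &= \frac{4}{\tfrac14 + \tfrac f2} - \frac{1}{\tfrac34 - \tfrac f2} = 0
  \;\Longrightarrow\; 4\left(\tfrac34 - \tfrac f2\right) = \tfrac14 + \tfrac f2
  \;\Longrightarrow\; f = \tfrac{11}{10}
\end{aligned}
```

The only stationary point, f = 11/10, lies outside [0, 1]. The derivative is decreasing in f and equals 16/3 - 4 = 4/3 > 0 at f = 1, so log L is increasing on the whole interval and the constrained maximum is at f = 1. Equivalently, the marginal probability of heads, 1/4 + f/2, "wants" to equal the observed 8/10, which would need f = 11/10.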

It seems that maybe the guide needs an extra parameter per sample to model the posterior of each flip of coin ‘A’. It seems to converge to the correct solution with the following guide, though the convergence is incredibly slow, taking thousands of iterations even with quite a big learning rate.

The slow convergence makes me think I’m probably still doing it wrong, and I would really appreciate an example of maximum likelihood inference with a non-empty guide (with any model) so that I can try to work out what I should actually do.

```python
def guide_mle(data):
    for i in pyro.plate("data", data.size(0)):
        c = pyro.param(
            f"coin_choice_posterior_{i}",
            torch.tensor(0.5),
            constraint=constraints.unit_interval
        )
        coin_choice = pyro.sample(f"coin_choice_{i}", dist.Bernoulli(c))
```
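For reference, the values that each `coin_choice_posterior_i` would ideally approach can be computed by Bayes' rule. This is a sketch in plain Python. It assumes the labeling of the original `model_mle`: `coin_choice = 1` means the biased coin with bias f, and each coin is chosen with prior probability 1/2. `posterior_biased` is a hypothetical helper. At f = 1 the posterior is 2/3 for every head and 0 for every tail. The tail parameters therefore have to be pushed to the edge of the unit interval, which may contribute to the slow convergence.

```python
def posterior_biased(x, f):
    """P(coin_choice = 1 | obs = x) under model_mle, for a single trial.

    coin_choice = 1 selects the biased coin (probability of heads f),
    coin_choice = 0 selects the fair coin; both have prior probability 1/2.
    """
    lik_biased = f if x == 1 else 1.0 - f
    lik_fair = 0.5
    return 0.5 * lik_biased / (0.5 * lik_biased + 0.5 * lik_fair)

for f in (0.5, 0.8, 1.0):
    print(f"f = {f}: heads -> {posterior_biased(1, f):.3f}, tails -> {posterior_biased(0, f):.3f}")
```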

i see. sorry for misunderstanding. you chose a rather difficult toy example… difficult because discrete latent variables are a bit of a challenge for variational inference. in particular if you try to learn a guide that contains discrete latent variables, the result is that you often end up with large gradient variance and difficult optimization (see svi part iii for details). in this case you can integrate/sum/enumerate out all the discrete latent variables using pyro’s machinery for automatic inference (see the enumeration and related tutorials for details).
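To make "summing out" concrete: enumerating coin A in a single trial replaces the latent variable with the marginal probability p(heads) = 1/2 · 1/2 + 1/2 · f. The objective then becomes the exact log-likelihood, and no guide over coin A is needed. The plain-Python sketch below uses no Pyro, and the helper names and learning rate are my own choices. It maximizes the summed-out log-likelihood by gradient ascent on an unconstrained logit u with f = sigmoid(u), which loosely mirrors how a `unit_interval`-constrained parameter is optimized.

```python
import math

# data as in the question: 8 heads (1) and 2 tails (0)
data = [1] * 8 + [0] * 2

def marginal_log_likelihood(f):
    """log p(data | f) with coin A summed out of every trial."""
    p_heads = 0.5 * 0.5 + 0.5 * f  # fair coin B or biased coin C, each w.p. 1/2
    return sum(math.log(p_heads) if x == 1 else math.log(1.0 - p_heads) for x in data)

def grad_wrt_f(f):
    """d/df of marginal_log_likelihood: +0.5/p_heads per head, -0.5/(1 - p_heads) per tail."""
    p_heads = 0.25 + 0.5 * f
    return sum(0.5 / p_heads if x == 1 else -0.5 / (1.0 - p_heads) for x in data)

def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))

u = 0.0  # f = sigmoid(0) = 0.5, the same starting value as in the thread
lr = 0.5
for step in range(1000):
    f = sigmoid(u)
    # chain rule through f = sigmoid(u): df/du = f * (1 - f)
    u += lr * grad_wrt_f(f) * f * (1.0 - f)

f_hat = sigmoid(u)
print(f"estimate of f: {f_hat:.4f}")
```

The estimate creeps toward 1.0 without reaching it, because the maximizer sits on the boundary of the allowed interval and sigmoid(u) < 1 for every finite u.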

here’s a complete snippet for your problem. note that there is no need to write down a guide because pyro effectively constructs one automatically.

```python
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

# data is 8 heads and 2 tails
data = torch.zeros(10)
data[0:8] = 1.0

def train(model, lr=0.05):
    pyro.clear_param_store()
    optimiser = Adam({"lr": lr})  # was undefined in the original snippet
    # we use an empty guide
    svi = SVI(model, lambda data: None, optimiser, loss=TraceEnum_ELBO(max_plate_nesting=0))

    n_steps = 1001
    for step in range(n_steps):
        loss = svi.step(data)
        if step % 50 == 0:
            print(f"[iter {step}] loss: {loss:.4f}")

@pyro.infer.config_enumerate
def model_mle(data):
    f = pyro.param(
        "latent_fairness",
        torch.tensor(0.5),
        constraint=constraints.unit_interval
    )
    # we introduce dice_probs so we can index into it in a fully vectorized way
    # (coin_choice 0 -> biased coin with bias f, 1 -> fair coin)
    dice_probs = torch.stack([f, torch.tensor(0.5)])
    for i in pyro.markov(range(len(data))):
        coin_choice = pyro.sample(f"coin_choice_{i}", dist.Bernoulli(0.5))
        pyro.sample(f"obs_{i}", dist.Bernoulli(dice_probs[coin_choice.long()]), obs=data[i])

train(model_mle)
print(f"latent fairness estimate: {pyro.param('latent_fairness').item():.4f}")
```

this should give a result like

latent fairness estimate: 0.9967

note that i used pyro.markov even though that isn’t strictly necessary here. another option would be to use a plate. you do need to use something, though, if you want to do automatic inference, because otherwise you’ll be summing over exponentially many possible latent coin states.
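The "exponentially many states" point can be seen in a small plain-Python sketch; the helper name `trial_prob` and the value f = 0.9 are illustrative choices. Summing the joint probability over all 2^10 configurations of the ten hidden coin-A flips gives the same likelihood as multiplying ten per-trial sums of two terms each, because the trials are independent. Declaring that structure (via a plate, or pyro.markov as in the snippet) is what lets enumeration avoid the joint sum.

```python
import itertools
import math

data = [1] * 8 + [0] * 2  # 8 heads, 2 tails
f = 0.9  # an arbitrary value of the unknown bias, for illustration

def trial_prob(x, a, f):
    """p(a) * p(x | a): a = 1 selects the biased coin (bias f), a = 0 the fair coin."""
    p_heads = f if a == 1 else 0.5
    return 0.5 * (p_heads if x == 1 else 1.0 - p_heads)

# naive: sum over all 2**len(data) joint assignments of the hidden coins
naive = 0.0
n_terms = 0
for assignment in itertools.product([0, 1], repeat=len(data)):
    naive += math.prod(trial_prob(x, a, f) for x, a in zip(data, assignment))
    n_terms += 1

# factorized: sum each hidden coin out separately, then multiply
factorized = math.prod(trial_prob(x, 0, f) + trial_prob(x, 1, f) for x in data)

print(f"naive: {naive:.6e} over {n_terms} joint terms")
print(f"factorized: {factorized:.6e} over {2 * len(data)} per-trial terms")
```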