I’m working through the example in the Pyro tutorial "MLE and MAP Estimation" (Pyro Tutorials 1.7.0 documentation).

That page says the guide can be empty because the model has no latent variables. That does seem to work, although the reason why is still a bit mysterious to me.
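For reference, `Trace_ELBO` estimates the standard evidence lower bound below. Here $\theta$ stands for the model's `pyro.param` values, $\phi$ for the guide's parameters, and $z$ for the latent (unobserved) sample sites. The bound is tight exactly when $q_\phi(z)$ equals the posterior $p_\theta(z \mid x)$. With no latent sites there is nothing to average over, so maximising the ELBO is the same as maximising the log-likelihood:

```latex
\mathrm{ELBO}(\theta, \phi)
  = \mathbb{E}_{q_\phi(z)}\!\left[\log p_\theta(x, z) - \log q_\phi(z)\right]
  \le \log p_\theta(x)

% no latent sites z: the expectation disappears
\mathrm{ELBO}(\theta) = \log p_\theta(x)
```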

What I’m stuck on is adapting that example to a model that does have latent variables, so the guide is no longer empty. I chose a simple case: first a fair coin is flipped, and then, depending on the result, either a second fair coin or a biased coin with unknown bias is flipped. Only the outcome of that second flip is observed.
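Summing out the first coin $c$ gives each observed flip $x_i$ a closed-form probability. Writing $f$ for the unknown bias (`latent_fairness` in the code) and $h$, $t$ for the numbers of observed heads and tails, the likelihood that an MLE should maximise is:

```latex
P(x_i = 1 \mid f) = \tfrac12 \cdot \tfrac12 + \tfrac12 \cdot f = \tfrac14 + \tfrac{f}{2}
P(x_i = 0 \mid f) = \tfrac12 \cdot \tfrac12 + \tfrac12 \cdot (1 - f) = \tfrac34 - \tfrac{f}{2}

\ell(f) = h \log\!\left(\tfrac14 + \tfrac{f}{2}\right) + t \log\!\left(\tfrac34 - \tfrac{f}{2}\right)
```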

My attempt is below, but it doesn’t behave as I expect. With the data I give it (8 heads and 2 tails), I’m fairly sure the maximum likelihood estimate of the unknown bias should be 1.0, but it gives an estimate of around 0.68 instead.
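A quick plain-Python check of that 1.0 claim (a sketch, assuming only the model described above with the 8 heads / 2 tails data; `marginal_log_likelihood` is a hypothetical helper, not a Pyro function) evaluates the likelihood with the first coin summed out on a grid over [0, 1]. Because this log-likelihood is concave and its unconstrained stationary point $(3h - t) / (2(h + t)) = 1.1$ lies above 1, it is increasing on the whole interval:

```python
import math

# Data from the question: 8 heads (1) and 2 tails (0).
HEADS, TAILS = 8, 2


def marginal_log_likelihood(f: float) -> float:
    """log p(data | f) with the first (fair) coin summed out."""
    p_heads = 0.5 * 0.5 + 0.5 * f  # fair second coin, or the biased one
    p_tails = 1.0 - p_heads
    return HEADS * math.log(p_heads) + TAILS * math.log(p_tails)


# Evaluate on a grid of candidate biases in [0, 1], endpoints included.
grid = [i / 1000 for i in range(1001)]
best_f = max(grid, key=marginal_log_likelihood)

# Where the derivative would be zero if f were unconstrained.
stationary = (3 * HEADS - TAILS) / (2 * (HEADS + TAILS))

print(f"grid argmax: {best_f:.3f}")           # grid argmax: 1.000
print(f"stationary point: {stationary:.2f}")  # stationary point: 1.10
```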

So I must have done something wrong somewhere, but what? How should I work out what the guide needs to contain in a simple case like this?

I am sorry for asking such basic questions.

```
import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO

# data is 8 heads (1.0) and 2 tails (0.0)
data = torch.zeros(10)
data[0:8] = 1.0


def train(model, guide, lr=0.01):
    pyro.clear_param_store()
    optimiser = pyro.optim.Adam({"lr": lr})
    svi = SVI(model, guide, optimiser, loss=Trace_ELBO())
    n_steps = 101
    for step in range(n_steps):
        loss = svi.step(data)
        if step % 50 == 0:
            print(f"[iter {step}] loss: {loss:.4f}")


def model_mle(data):
    # Unknown bias of the biased coin, constrained to [0, 1]
    f = pyro.param(
        "latent_fairness",
        torch.tensor(0.5),
        constraint=constraints.unit_interval,
    )
    for i in pyro.plate("data", data.size(0)):
        # Latent: the first (fair) flip decides which coin is used next
        coin_choice = pyro.sample(f"coin_choice_{i}", dist.Bernoulli(0.5))
        if coin_choice.item() == 0.0:
            bias = 0.5  # second coin is fair
        else:
            bias = f    # second coin is the biased one
        # Observed: the outcome of the second flip
        pyro.sample(f"obs_{i}", dist.Bernoulli(bias), obs=data[i])


def guide_mle(data):
    # The guide provides a distribution for each latent site in the model;
    # here every coin_choice_i gets a fixed Bernoulli(0.5)
    for i in pyro.plate("data", data.size(0)):
        coin_choice = pyro.sample(f"coin_choice_{i}", dist.Bernoulli(0.5))


train(model_mle, guide_mle)
print(f"latent fairness estimate: {pyro.param('latent_fairness').item():.4f}")
```
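One way to see what this particular guide optimises (a sketch, assuming the guide is exactly the fixed `Bernoulli(0.5)` above and the same 8 heads / 2 tails): with a fixed guide, the $-\log q$ term does not depend on `latent_fairness`, so training maximises the expected complete-data log-likelihood under that guide rather than the marginal likelihood. The hypothetical helper `fixed_guide_elbo` (plain Python, not part of Pyro) computes that ELBO exactly by summing over both values of each `coin_choice_i`:

```python
import math

HEADS, TAILS = 8, 2
Q_C = 0.5  # the guide's fixed probability that coin_choice_i == 1


def fixed_guide_elbo(f: float) -> float:
    """Exact ELBO of the question's model when q(c_i) = Bernoulli(Q_C) for all i.

    ELBO = sum_i E_q[log p(c_i) + log p(x_i | c_i, f) - log q(c_i)]
    """
    total = 0.0
    for x, count in ((1, HEADS), (0, TAILS)):
        per_obs = 0.0
        for c, q in ((1, Q_C), (0, 1.0 - Q_C)):
            bias = f if c == 1 else 0.5  # biased coin only when c == 1
            log_lik = math.log(bias if x == 1 else 1.0 - bias)
            log_prior = math.log(0.5)    # p(c_i) = Bernoulli(0.5)
            per_obs += q * (log_prior + log_lik - math.log(q))
        total += count * per_obs
    return total


# Open interval: at f = 1 a tail under the biased coin has log(0).
grid = [i / 1000 for i in range(1, 1000)]
best_f = max(grid, key=fixed_guide_elbo)
print(f"fixed-guide ELBO argmax: {best_f:.3f}")  # fixed-guide ELBO argmax: 0.800
```

The $f$-dependent part of this objective is $\tfrac12\left(h \log f + t \log(1 - f)\right)$, which peaks at $h / (h + t) = 0.8$, so with this guide even a well-converged run would be expected to settle near 0.8 rather than at the MLE of 1.0.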