Correct MAP guide without using automatic guide generation


I am trying to understand how to build a maximum a posteriori (MAP) guide without using automatic guide generation. Below is a basic model and two potential guides (1 & 2). After estimation, I am interested in the values of both alpha and theta.

import pyro
import pyro.distributions as dist
import torch
from torch.distributions import constraints
M = 2
N = 3

def model():
    # hyperparameter
    alpha = torch.ones(M)
    # theta ~ Dirichlet(alpha)
    with pyro.plate('n', N):
        pyro.sample('theta', dist.Dirichlet(alpha)) 

This guide does not look correct but it allows me to inspect both alpha and theta:

def map_guide1():
    alpha = pyro.param('alpha', torch.ones(N, M), constraint=constraints.positive)
    # .sample() does not track gradients, so optimization cannot update alpha through theta
    theta = dist.Dirichlet(alpha).sample()
    with pyro.plate('n', N):
        pyro.sample('theta', dist.Delta(theta, event_dim=1))

This guide looks correct but it only gives me theta, not alpha:

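def map_guide2():
    # the param name must differ from the sample site name 'theta', otherwise Pyro's trace reports a name clash
    theta_map = pyro.param('theta_map', torch.ones(N, M), constraint=constraints.simplex)
    # match the model's plate and treat the last dim (size M) as the event dim
    with pyro.plate('n', N):
        pyro.sample('theta', dist.Delta(theta_map, event_dim=1))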


See this old forum thread for a discussion of a question very similar to yours.

If you want a MAP estimate for alpha, you need to draw it with pyro.sample from a prior distribution in your model and from a Delta distribution in your guide, just as your second guide does with theta, rather than treating it as a constant. If you instead want a maximum likelihood estimate of alpha, wrap it in a pyro.param call in the model and omit it from your guide. Pyro’s inference algorithms can’t see or update values that don’t come from pyro.param or pyro.sample statements.