Guides and modules

I’m trying to figure out how to combine Pyro modules with custom guides. What is the “pyronic” way of doing this? I came across this thread on the forum, in which a PyroModule is defined for both the model and the guide. The model is the well-known fair-coin example; my version of it is listed below. Importantly, this does not work: nothing is optimized in my training loop, and the parameter store is still empty after training. If I use an AutoDelta guide instead of my PyroModule guide, I do get the right MAP estimate.

import pyro
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.nn import PyroModule, PyroParam, PyroSample
from pyro.infer.autoguide import AutoDelta
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from tqdm import trange
import torch

#pyro.clear_param_store() # for use in notebook

# implement the model as the forward method of a PyroModule
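class Model(PyroModule):
    def __init__(self):
        super().__init__()  # PyroModule must be initialized before PyroSample/PyroParam attributes are set
        self.fairness = PyroSample(prior=dist.Beta(0.5, 0.5))

    def forward(self, data):
        fairness = self.fairness  # accessing the attribute creates the "fairness" sample site
        with pyro.plate("obs", len(data)):
            return pyro.sample("data", dist.Bernoulli(fairness), obs=data)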

# and implement the guide as the forward method of another module
class Guide(PyroModule):
    def __init__(self):
        super().__init__()  # PyroModule must be initialized before PyroSample/PyroParam attributes are set
        self.alpha = PyroParam(torch.tensor(0.5), constraint=constraints.positive)
        self.beta = PyroParam(torch.tensor(0.5), constraint=constraints.positive)

        # the Beta distribution is constructed once, here in __init__
        self.fairness = PyroSample(prior=dist.Beta(self.alpha, self.beta))

    def forward(self, data):
        return self.fairness


model = Model()
guide = Guide()

#guide = AutoDelta(model) # this DOES work

optim = Adam({"lr": 1e-2})

svi = SVI(model, guide, optim, loss=Trace_ELBO())

coin_data = torch.randint(2, size=(1000,), dtype=torch.float)

elbo = []
# training loop
for epoch in (pbar := trange(100)):
    loss = svi.step(coin_data)
    elbo.append((epoch, loss))
    pbar.set_description(f"train loss (-ELBO estimate): {loss:0.2f}")

store = pyro.get_param_store() # the param_store is empty!!

for k in store:
    print(k, store[k])

I’m hoping someone can explain why this does not work and how to fix it. I’m clearly missing some basic understanding of how Pyro works and how it should be used, even after working through many of the examples.

Thanks to anyone who is willing to help!!