SVI doesn't change parameters

I am trying to build a model of a system in which a probabilistic switch controls the output of another random process. What I see is that the SVI loss fluctuates without converging, and all of the parameters stay constant. Is there something about Bernoulli that blocks gradients? Here is a simplified version of the code:

# imports needed to run the snippet
import numpy as np
import torch
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from scipy.stats import norm

TT = torch.Tensor  # assumed alias; the original post doesn't define TT

def sample_model(data):
    with pyro.iarange('n', len(data)):
        rate = pyro.sample('rate', dist.LogNormal(1, 1).expand([len(data)]))
        active = pyro.sample('active', dist.Bernoulli(0.5).expand([len(data)]).independent(0))
        adj_rate = rate * active + 0.1
        result = pyro.sample('result', dist.Poisson(adj_rate), obs=data)

def sample_guide(data):
    with pyro.iarange('n', len(data)):
        rate = pyro.sample('rate', dist.LogNormal(1, 1).expand([len(data)]))
        active_logit = pyro.param('active_logit', TT(norm(0, 1).rvs(size=len(data))))
        active = pyro.sample('active', dist.Bernoulli(logits=active_logit).independent(0))

pyro.clear_param_store()
optimizer = Adam({"lr": 0.005, "betas": (0.95, 0.999)})
svi = SVI(sample_model, sample_guide, optimizer, loss=Trace_ELBO())

# take 100 SVI steps, recording the loss and the current logits at each step
results = np.array([np.concatenate([[svi.step(TT([1, 2, 0, 4, 5]))], pyro.param('active_logit').detach().numpy()])
                    for _ in range(100)])

# loss over time
plt.plot(results[:, 0])
plt.show()

# each logit over time
for i in range(results.shape[1] - 1):
    plt.plot(results[:, i + 1], label=str(i))
plt.legend()
plt.show()

You’ll need to tell Pyro about your trainable parameters using pyro.param:

from torch.distributions import constraints

@pyro.poutine.broadcast  # <--- this allows you to omit .expand() below
def sample_model(data):
    loc = pyro.param('loc', torch.tensor(1.))
    scale = pyro.param('scale', torch.tensor(1.),
                       constraint=constraints.positive)
    p = pyro.param('p', torch.tensor(0.5),
                   constraint=constraints.unit_interval)
    with pyro.iarange('n', len(data)):
        rate = pyro.sample('rate', dist.LogNormal(loc, scale))
        assert rate.shape == data.shape
        active = pyro.sample('active', dist.Bernoulli(p))
        adj_rate = rate * active + 0.1
        result = pyro.sample('result', dist.Poisson(adj_rate), obs=data)

…and similarly for the guide.
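
For concreteness, here is one possible matching guide, a minimal sketch in the same style (the parameter names loc_q, scale_q, and logit_q are illustrative, not from this thread):

@pyro.poutine.broadcast
def sample_guide(data):
    # learnable variational parameters (names illustrative)
    loc_q = pyro.param('loc_q', torch.tensor(1.))
    scale_q = pyro.param('scale_q', torch.tensor(1.),
                         constraint=constraints.positive)
    logit_q = pyro.param('logit_q', torch.zeros(len(data)))
    with pyro.iarange('n', len(data)):
        pyro.sample('rate', dist.LogNormal(loc_q, scale_q))
        pyro.sample('active', dist.Bernoulli(logits=logit_q))

The guide samples the same unobserved sites ('rate' and 'active') as the model, but parameterizes their distributions with the learnable values.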

That worked, thanks. From reading the docs I thought I was supposed to use pyro.param only in the guide, not in the model, but I guess I was mistaken. So the guide and model are now identical except for the observation.

Ravi

It depends on your model. There are model parameters and guide parameters, and both kinds need to be wrapped in pyro.param statements if you want to learn them via inference.

So the guide and model are now identical except for the observation.

Sorry, your initial intuition was correct; I was simply demonstrating syntax. Usually you'd use fixed "prior" parameters in the model and learnable "posterior" parameters in the guide, and only the learnable parameters would require pyro.param statements.
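
Concretely, for the model in this thread that usual pattern would keep the fixed priors (the 1, 1, and 0.5 from the original post) inline in the model, with no pyro.param statements there; a minimal sketch, to be paired with a guide like the one above:

@pyro.poutine.broadcast
def sample_model(data):
    with pyro.iarange('n', len(data)):
        # fixed prior hyperparameters, written inline: no pyro.param needed
        rate = pyro.sample('rate', dist.LogNormal(1., 1.))
        active = pyro.sample('active', dist.Bernoulli(0.5))
        pyro.sample('result', dist.Poisson(rate * active + 0.1), obs=data)

Only the guide's parameters (loc_q, scale_q, and logit_q in the sketch) are then updated by SVI.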