Multiple executions of `model()` during single SVI step

Hey everyone,

I couldn't find anything about this in the docs, so I'm asking here: why is the `model()` function sometimes executed multiple times during a single SVI step?

More precisely, I observed that `model()` is executed multiple times per step with:

  • num_particles=N, vectorize_particles=True: 2 times on the first step, then once per step (for any N >= 1)
  • num_particles=N, vectorize_particles=False: N times per step

So my questions are:

  1. Is there a particular reason why this happens?
  2. Can I switch it off, and when should I?

Thanks in advance!

As a minimal working example, I took the code from the tutorial and added some print statements. This is the output (the matching print calls are the last lines of the code below):

Before first step
- Running Model...
- Running Model...
After first step
- Running Model...
After second step

The full code:

import torch
import torch.distributions.constraints as constraints
import pyro
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

pyro.clear_param_store()

data = []
for _ in range(6):
    data.append(torch.tensor(1.0))
for _ in range(4):
    data.append(torch.tensor(0.0))

def model(data):
    print("- Running Model...")
    alpha0 = torch.tensor(10.0)
    beta0 = torch.tensor(10.0)
    f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0))
    for i in range(len(data)):
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

def guide(data):
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0), constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0), constraint=constraints.positive)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

adam_params = {"lr": 0.0005, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)

# setup the inference algorithm
svi = SVI(model, guide, optimizer, loss=Trace_ELBO(retain_graph=False, num_particles=4, vectorize_particles=True))

print("Before first step")
svi.step(data)
print("After first step")
svi.step(data)
print("After second step")

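If you add `max_plate_nesting=0` to the `Trace_ELBO` constructor, the “extra” execution will go away. (For other models, e.g. ones that use `pyro.plate`, you may need `max_plate_nesting > 0`: it is an upper bound on the number of nested plates in the model.)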

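Basically, Pyro needs to know the plate structure so that it can figure out which tensor dimension can safely be used as the particle dimension. Since you didn’t tell it, and since Pyro doesn’t do static code analysis, it runs the model once to figure out the required info. The result is cached, which is why the extra run only happens on the first step. With vectorize_particles=False no particle dimension is needed, so there is no extra run; the N executions are simply one model run per particle.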
