Get multiple posterior samples using guide

I’ve searched around, and the only thing I found is that to do inference on my regression parameters, I can call guide() to draw one sample at the latest parameter values after SVI. The problem is that I want 1 million samples: what is the most efficient way to do this? I could use a for loop, but that seems dumb.

I also tried limiting return_sites in Predictive, which seems to work, but I don’t know if it is correct:

from pyro.infer import Predictive

num_samples = 10000
predictive = Predictive(model, guide=guide, num_samples=num_samples, return_sites=('b0', 'phi'))

If I don’t limit return_sites, my GPU runs out of memory, since Predictive generates a prediction for each data point y.

Using Predictive with specified return_sites is completely correct.

You can use pyro.poutine.trace(guide).get_trace(*args, **kwargs) to get a single sample directly from your guide.
In some cases, you might also want to read the variational distribution parameters directly, e.g. with pyro.param("a").item(), rather than sampling from the distribution.


Thanks for the quick response. Is there a way to run Predictive faster? It seems pretty slow.

@snowdustdj are you using the parallel=True option in Predictive?

I tried it, but it seems to require some special handling of the plate. Here is my model (a simplified version I’m using to test some concepts before I dive deeper), with an AutoNormal guide.

import torch
import pyro
import pyro.distributions as dist


def model(y_tensor, n_tensor):
    """
    All obs are split into M segments. Each segment follows an NB distribution with its own
    (alpha, beta), parameterized as in dist.GammaPoisson. The purpose is to estimate all alphas and betas
    for all segments independently, but in this single code block (there is no dependency among segments
    for now, but later on I may add partial pooling, features, etc.).
    y_tensor: size N, one entry per obs; each obs belongs to one segment. Data are ordered from seg 1 to seg M.
    n_tensor: the obs count for seg 1, 2, ..., M, with sum(n_tensor) = len(y_tensor).
    """
    n_segments = len(n_tensor)
    b0 = pyro.sample("b0", dist.Normal(loc=torch.zeros((n_segments, 1)), scale=5))
    phi = pyro.sample("phi", dist.HalfCauchy(2.0 * torch.ones((n_segments, 1))))
    mu = torch.exp(b0)
    # assumes an even number of segments: row 0 is the first half, row 1 the second half
    mu_reshape = mu.reshape(2, -1)
    uplift = pyro.deterministic('uplift', mu_reshape[1, :] - mu_reshape[0, :])
    beta = phi/mu
    alpha = phi
    # broadcast per-segment parameters to per-observation parameters
    alpha_cast = alpha.repeat_interleave(n_tensor)
    beta_cast = beta.repeat_interleave(n_tensor)

    with pyro.plate("data", len(y_tensor)):
        pyro.sample("obs", dist.GammaPoisson(alpha_cast, beta_cast), obs=y_tensor)

# fit the guide with SVI
import pyro.optim as optim

y_tensor = torch.tensor(y_list, dtype=torch.float)
n_tensor = torch.tensor(n_list)

from pyro.infer.autoguide import AutoNormal, init_to_mean
from pyro.infer import SVI, Trace_ELBO

num_iters = 10000
guide = AutoNormal(model, init_loc_fn=init_to_mean)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .001}),
          loss=Trace_ELBO())
loss = []
for i in range(num_iters):
    elbo = svi.step(y_tensor, n_tensor)
    loss.append(elbo)
    if i % 500 == 0:
        print("Elbo loss: {}".format(elbo))

# sample nodes
from pyro.infer import Predictive

num_samples = 100000
predictive = Predictive(model, guide=guide, num_samples=num_samples, return_sites=('b0', 'phi', 'uplift'))
pred_res_raw = predictive(y_tensor, n_tensor)

It takes 6 minutes to get 100k samples. If I set parallel=True, it complains:
ValueError: Shape mismatch inside plate('_num_predictive_samples') at site b0 dim -2, 1000 vs 40

When parallel=False, Predictive has to run your model once per sample, which as you are seeing will be very slow for large numbers of samples.

The reason you are seeing the error with parallel=True is that Predictive(..., parallel=True) uses an additional pyro.plate to perform vectorized sampling using PyTorch’s broadcasting semantics. More generally, Pyro’s inference algorithms may introduce additional batch dimensions to the left of all plate and event dimensions in your model, as described in the tensor shapes tutorial. You should be able to avoid most such errors entirely as long as you apply the rules of thumb in the summary section of the tutorial when writing Pyro code.

For example, in your model, you will need to annotate the batch dimension in sites b0 and phi with either .to_event() or a plate, and reshape mu in a way that is compatible with the extra batch dimensions:

with pyro.plate("segments", n_segments):
    b0 = pyro.sample("b0", dist.Normal(0, 5))
    phi = pyro.sample("phi", dist.HalfCauchy(2.))
mu = torch.exp(b0)
mu_reshape = mu.reshape(*(mu.shape[:-1] + (2, n_segments // 2)))
uplift = pyro.deterministic("uplift", mu_reshape[..., 1, :] - mu_reshape[..., 0, :], event_dim=1)