Getting samples for all `pyro.sample` objects from a model

I have a simple model

import pyro

def model(x, y=None):
    theta_0 = pyro.sample("t_0", pyro.distributions.Normal(0.0, 1.0))
    theta_1 = pyro.sample("t_1", pyro.distributions.Normal(0.0, 1.0))
    with pyro.plate("data", len(x)):
        return pyro.sample(
            "obs", pyro.distributions.Normal(x * theta_1 + theta_0, 1.0), obs=y
        )

I wish to create a dataset of theta_0, theta_1 and corresponding ys. I’m doing this for educational purposes to illustrate the samples drawn from a prior.

I can modify the return statement as follows:

def model(x, y=None):
    theta_0 = pyro.sample("t_0", pyro.distributions.Normal(0.0, 1.0))
    theta_1 = pyro.sample("t_1", pyro.distributions.Normal(0.0, 1.0))
    with pyro.plate("data", len(x)):
        return theta_0, theta_1, pyro.sample(
            "obs", pyro.distributions.Normal(x * theta_1 + theta_0, 1.0), obs=y
        )

and then use a for loop to draw samples:

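import torch
import matplotlib.pyplot as plt

x = torch.linspace(-2.0, 2.0, 50)  # example inputs (assumed; not given in the post)
t0s = []
t1s = []
ys = []
for _ in range(10):
    t0, t1, y = model(x)  # one prior draw of both parameters and the observations
    t0s.append(t0.item())
    t1s.append(t1.item())
    ys.append(y)
    plt.plot(x, x * t1 + t0)  # regression line implied by this draw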

This will give me the desired plot.

Is there a better way to generate such samples?

Ideally, I’d like to use the same model both for generating such data and then for learning the guide with SVI.

Hi @nipun_batra, using your original model, you can draw a batch of samples with pyro.plate and extract them from a trace:

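import pyro
import pyro.poutine as poutine
import matplotlib.pyplot as plt

# `model` is the original model from the question and `x` is the input tensor.
# The outer plate vectorizes the whole model over 10 prior draws.
with pyro.plate("samples", 10, dim=-2):
    trace = poutine.trace(model).get_trace(x)
t0s = trace.nodes["t_0"]["value"].squeeze()  # shape (10,)
t1s = trace.nodes["t_1"]["value"].squeeze()  # shape (10,)
ys = trace.nodes["obs"]["value"].squeeze()   # shape (10, len(x))
for t0, t1, y in zip(t0s, t1s, ys):
    plt.plot(x, x * t1 + t0)  # regression line for this draw
    plt.scatter(x, y, s=5)    # noisy observations drawn with it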

Note that we need to set the outer plate’s dim=-2 to avoid a collision with the inner “data” plate, which defaults to dim=-1. See the tensor shapes tutorial for details.


Many thanks. This solves my issue.