Generating samples from the posterior predictive without subsampling

Hi there, fairly new Pyro user here!

So far, I’ve created a PyroModule:

import torch.nn as nn
import pyro
from pyro.nn import PyroModule, PyroSample

class SurvivalModel(PyroModule):
    def __init__(self, in_features, out_features, subsample_size=64):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(
            pyro.distributions.Normal(0., 100.).expand([out_features, in_features]).to_event(2)
        )
        self.linear.bias = PyroSample(
            pyro.distributions.Normal(365., 100.).expand([out_features]).to_event(1)
        )
        self.subsample_size = subsample_size

    def forward(self, x, y=None, truncation_label=None):
        mean = self.linear(x).squeeze(-1)
        
        with pyro.plate("data", size=x.shape[0], subsample_size=self.subsample_size) as ind:
            ...

In the above, I’ve specified a subsample_size.

However, when I was trying to generate samples from the posterior predictive using Predictive, I started to run into problems with the following code:

predictive = Predictive(
    model, 
    guide=guide, 
    num_samples=800,
    return_sites=("linear.weight", "obs", "_RETURN")
)
samples = predictive(
    x=X_test, 
    y=None, 
    truncation_label=None 
)

In the above, the model was still subsampling from X_test instead of using the entire X_test. Am I doing something wrong in my approach?

Any feedback would be appreciated!

Ended up refactoring to:

def forward(self, x, y=None, truncation_label=None, subsample=True):
    mean = self.linear(x).squeeze(-1)
    # Conditionally independent - perhaps add back in x.shape[0] arg
    with pyro.plate(
        "data", 
        size=x.shape[0], 
        subsample_size=self.subsample_size if subsample else None
    ) as ind:
        ...

Not sure if that is the prescribed way of doing it. Any feedback would still be greatly appreciated!

Hi @alim1990
I am also new to Pyro.

Looking at the code, I guess you are trying to use minibatched data while passing in the complete x and y.

Why not do it the other way around?
Pass minibatched data to your model and use pyro.plate purely to rescale the log-likelihood. That way the model is guaranteed to train on every data point, and training is faster because mean is computed only for the minibatch instead of the entire dataset.

In code, it would be something like:

for batch in DataLoader(train_dataset, batch_size=64):
    x, y = batch
    # svi.step takes a gradient step and returns the loss
    loss = svi.step(x, full_size=num_of_data_samples, y=y)

def forward(self, x, full_size, y=None, truncation_label=None):
    mean = self.linear(x).squeeze(-1)
    # plate over the full dataset; subsample_size=x.shape[0] rescales the
    # minibatch log-likelihood by full_size / x.shape[0]
    with pyro.plate(
        "data", 
        size=full_size, 
        subsample_size=x.shape[0]
    ) as ind:
        ...

samples = predictive(
    x=X_test, 
    full_size=X_test.shape[0],
    y=None, 
    truncation_label=None 
)

Thanks