Predictive and functions that return tuples

Hi all, I have a model function that returns two values (i.e., it returns a tuple). I want to use `Predictive` to generate posterior samples that include these returned values. A simplified version looks like this:

import pyro
from pyro.infer import Predictive
import pyro.distributions as dist
import torch

def model():
    """Draw "x" and "y" from a standard-normal prior and return both.

    Returns:
        tuple: ``(x, y)``, the values taken at the sample sites "x" and "y".
        This is a Python tuple, not a tensor.
    """
    std_normal = dist.Normal(torch.tensor(0.0), torch.tensor(1.0))
    x = pyro.sample("x", std_normal)
    y = pyro.sample("y", std_normal)
    return x, y

def guide():
    """Variational guide placing a narrow Normal on each latent site.

    Site "x" is centred at -0.5 and site "y" at 0.5, both with scale 0.1.
    Returns nothing.
    """
    scale = torch.tensor(0.1)
    for site_name, loc in (("x", -0.5), ("y", 0.5)):
        pyro.sample(site_name, dist.Normal(torch.tensor(loc), scale))
       
    
# Reproduces the problem: with return_sites=["_RETURN"], Predictive collects
# model()'s return value from every draw and combines them with torch.stack.
# model() returns a tuple rather than a tensor, so the call below raises
# "TypeError: expected Tensor as element 0 in argument 0, but got tuple".
pred = Predictive(model, guide=guide, num_samples=100, return_sites=["_RETURN"])

sam = pred() ## ERROR!!!

I think the error occurs because tuples are not tensors and cannot be stacked. Is there a simple way to fix this?

Thanks!

Can you share the stack trace?

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[19], line 19
     14     pyro.sample("y", dist.Normal(torch.tensor(0.5), torch.tensor(0.1)))
     17 pred = Predictive(model, guide=guide, num_samples=100, return_sites=["_RETURN"])
---> 19 sam = pred() ## ERROR!!!

File ~/Projects/TRM/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Projects/TRM/venv/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/Projects/TRM/venv/lib/python3.10/site-packages/pyro/infer/predictive.py:127, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    124     return_site_shapes["_RETURN"] = shape
    126 if not parallel:
--> 127     return _predictive_sequential(
    128         model,
    129         posterior_samples,
    130         model_args,
    131         model_kwargs,
    132         num_samples,
    133         return_site_shapes,
    134         return_trace=False,
    135     )
    137 trace = poutine.trace(
    138     poutine.condition(vectorize(model), reshaped_samples)
    139 ).get_trace(*model_args, **model_kwargs)
    140 predictions = {}

File ~/Projects/TRM/venv/lib/python3.10/site-packages/pyro/infer/predictive.py:61, in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     59     return collected
     60 else:
---> 61     return {
     62         site: torch.stack([s[site] for s in collected]).reshape(shape)
     63         for site, shape in return_site_shapes.items()
     64     }

File ~/Projects/TRM/venv/lib/python3.10/site-packages/pyro/infer/predictive.py:62, in <dictcomp>(.0)
     59     return collected
     60 else:
     61     return {
---> 62         site: torch.stack([s[site] for s in collected]).reshape(shape)
     63         for site, shape in return_site_shapes.items()
     64     }

TypeError: expected Tensor as element 0 in argument 0, but got tuple

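A simple solution is to wrap `model` in an auxiliary function, `model_wrapper`, that records each returned value as a `pyro.deterministic` site, which `Predictive` can then collect by name: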

def model_wrapper():
    """Expose the values returned by ``model`` as named Pyro sites.

    ``model`` returns a tuple, which ``Predictive`` cannot stack when asked for
    "_RETURN". Recording each element with ``pyro.deterministic`` turns it into
    an ordinary site that ``Predictive`` can collect by name, without having to
    modify ``model`` itself.

    The deterministic sites must not reuse the names of sample sites already
    created inside ``model`` ("x" and "y"). Pyro refuses to record two sample
    sites with the same name in one trace, so distinct names are used here.
    """
    x, y = model()
    pyro.deterministic("x_ret", x)
    pyro.deterministic("y_ret", y)


# Predictive needs either posterior_samples or num_samples; without either
# it cannot know how many draws to make.
# (In this toy example "x" and "y" are already sample sites of ``model``, so
# return_sites=["x", "y"] on ``model`` itself would also work. The wrapper
# matters when the returned values are computed rather than sampled directly.)
pred = Predictive(
    model_wrapper, guide=guide, num_samples=100, return_sites=["x_ret", "y_ret"]
)
sam = pred()

x, y = sam["x_ret"], sam["y_ret"]

Obviously, I could also do this directly in `model`, but then I would have to change a lot of other code.