Making imputation work with vectorized ELBO estimators

Hi Pyro devs!

I am trying to get some code where I learn an imputation distribution, to work with TraceTailAdaptive_ELBO, since it seemed to work well with some of my other models.
However it seems hat I get shape errors when I try to run the model because of the added extra dimensions on the left, when I try to impute the data.
I have tried expanding the observed data, but then it complains about the shape of observed values.

Any ideas on how to change the model, so that it would work with vectorized estimators like TraceTailAdaptive_ELBO?

Thanks in advance! :slight_smile:

Minimalized Code:

import pyro
import torch
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoMultivariateNormal
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceTailAdaptive_ELBO, RenyiELBO
from tqdm import tqdm

data = dist.MultivariateNormal(loc=torch.tensor([1.0, 5.0, -3.0]), scale_tril=torch.tensor([[1.0, 0.0, 0.0], 
                                                                                            [3.0, 1.0, 0.0],
                                                                                            [5.0, 10.0, 1.0]])).sample((1000,))

missing_lengths = torch.randint(0, 3, (1000,))
missing_mask = torch.arange(3)
missing_mask = missing_mask.unsqueeze(-2) <= missing_lengths.unsqueeze(-1)
data[~missing_mask] = 0

def model(data, missing_mask):
    imp_loc = pyro.sample('imp_loc', dist.MultivariateNormal(torch.zeros(3), scale_tril=torch.eye(3)))
    imp_scale = pyro.sample('imp_scale', dist.HalfCauchy(torch.tensor(1.)))
    global_loc = pyro.sample('global_loc', dist.MultivariateNormal(torch.ones(3), torch.eye(3) * 0.1))
    with pyro.plate('data', data.size(0), dim=-1):
        imp = pyro.sample('imp', dist.MultivariateNormal(imp_loc, scale_tril=torch.eye(3) * imp_scale))
        data_obs = data * missing_mask.float() + imp * (~missing_mask).float()
        pyro.sample('obs', dist.MultivariateNormal(global_loc, scale_tril=torch.eye(3)), obs=data_obs)

guide = AutoMultivariateNormal(model)
optim = Adam(dict(lr=1e-3))
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, elbo)

n_steps = 500
pg = tqdm(range(n_steps))
for ep in pg:
    loss = svi.step(data, missing_mask)
    pg.set_description(f"Epoch {ep}: {loss}")

If I change Trace_ELBO() to TraceTailAdaptive_ELBO(num_particles=10, vectorize_particles=True) in the above code for elbo, I get the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    146             try:
--> 147                 ret = self.fn(*args, **kwargs)
    148             except (ValueError, RuntimeError):

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/messenger.py in _wraps(*args, **kwargs)
     26             with self:
---> 27                 return fn(*args, **kwargs)
     28         _wraps.msngr = self

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in wrapped_fn(*args, **kwargs)
    134             with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting):
--> 135                 return fn(*args, **kwargs)
    136 

<ipython-input-4-c7fc2e8bb656> in model(data, missing_mask)
      5     with pyro.plate('data', data.size(0), dim=-1):
----> 6         imp = pyro.sample('imp', dist.MultivariateNormal(imp_loc, scale_tril=torch.eye(3) * imp_scale))
      7         data_obs = data * missing_mask.float() + imp * (~missing_mask).float()

RuntimeError: The size of tensor a (3) must match the size of tensor b (10) at non-singleton dimension 0

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
<ipython-input-11-e328309c426f> in <module>
      2 pg = tqdm(range(n_steps))
      3 for ep in pg:
----> 4     svi.step(data, missing_mask)
      5     loss = loss_estim.loss(model, guide, data, missing_mask)
      6     pg.set_description(f"Epoch {ep}: {loss}")

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
     97         # get loss and compute gradients
     98         with poutine.trace(param_only=True) as param_capture:
---> 99             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    100 
    101         params = set(site["value"].unconstrained()

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    123         loss = 0.0
    124         # grab a trace from the generator
--> 125         for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
    126             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
    127             loss += loss_particle / self.num_particles

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, *args, **kwargs)
    163             if self.max_plate_nesting == float('inf'):
    164                 self._guess_max_plate_nesting(model, guide, *args, **kwargs)
--> 165             yield self._get_vectorized_trace(model, guide, *args, **kwargs)
    166         else:
    167             for i in range(self.num_particles):

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in _get_vectorized_trace(self, model, guide, *args, **kwargs)
    145         return self._get_trace(self._vectorized_num_particles(model),
    146                                self._vectorized_num_particles(guide),
--> 147                                *args, **kwargs)
    148 
    149     @abstractmethod

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, *args, **kwargs)
     50         """
     51         model_trace, guide_trace = get_importance_trace(
---> 52             "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
     53         if is_validation_enabled():
     54             check_if_enumerated(guide_trace)

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs)
     42     guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
     43     model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
---> 44                                 graph_type=graph_type).get_trace(*args, **kwargs)
     45     if is_validation_enabled():
     46         check_model_guide_match(model_trace, guide_trace, max_plate_nesting)

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    167         Calls this poutine and returns its trace instead of the function's return value.
    168         """
--> 169         self(*args, **kwargs)
    170         return self.msngr.get_trace()

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    151                 six.reraise(exc_type,
    152                             exc_type(u"{}\n{}".format(exc_value, shapes)),
--> 153                             traceback)
    154             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    155         return ret

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
    690                 value = tp()
    691             if value.__traceback__ is not tb:
--> 692                 raise value.with_traceback(tb)
    693             raise value
    694         finally:

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    145                                       args=args, kwargs=kwargs)
    146             try:
--> 147                 ret = self.fn(*args, **kwargs)
    148             except (ValueError, RuntimeError):
    149                 exc_type, exc_value, traceback = sys.exc_info()

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/messenger.py in _wraps(*args, **kwargs)
     25         def _wraps(*args, **kwargs):
     26             with self:
---> 27                 return fn(*args, **kwargs)
     28         _wraps.msngr = self
     29         return _wraps

/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in wrapped_fn(*args, **kwargs)
    133                 return fn(*args, **kwargs)
    134             with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting):
--> 135                 return fn(*args, **kwargs)
    136 
    137         return wrapped_fn

<ipython-input-4-c7fc2e8bb656> in model(data, missing_mask)
      4     global_loc = pyro.sample('global_loc', dist.MultivariateNormal(torch.ones(3), torch.eye(3) * 0.1))
      5     with pyro.plate('data', data.size(0), dim=-1):
----> 6         imp = pyro.sample('imp', dist.MultivariateNormal(imp_loc, scale_tril=torch.eye(3) * imp_scale))
      7         data_obs = data * missing_mask.float() + imp * (~missing_mask).float()
      8         pyro.sample('obs', dist.MultivariateNormal(global_loc, scale_tril=torch.eye(3)), obs=data_obs)

RuntimeError: The size of tensor a (3) must match the size of tensor b (10) at non-singleton dimension 0
                Trace Shapes:            
                 Param Sites:            
                Sample Sites:            
num_particles_vectorized dist         |  
                        value      10 |  
                 imp_loc dist 10    1 | 3
                        value 10    1 | 3
               imp_scale dist 10    1 |  
                        value 10    1 |  
              global_loc dist 10    1 | 3
                        value 10    1 | 3
                    data dist         |  
                        value    1000 |  

Hmm… sounds not unlike the kind of errors I’m trying to deal with. I’m looking forward to seeing how your question is answered. Sorry that I can’t help, myself…

(in other words: +1 on this thread.)