Hi Pyro devs!
I am trying to get some code where I learn an imputation distribution, to work with TraceTailAdaptive_ELBO
, since it seemed to work well with some of my other models.
However it seems hat I get shape errors when I try to run the model because of the added extra dimensions on the left, when I try to impute the data.
I have tried expanding the observed data, but then it complains about the shape of observed values.
Any ideas on how to change the model, so that it would work with vectorized estimators like TraceTailAdaptive_ELBO
?
Thanks in advance!
Minimalized Code:
import pyro
import torch
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoMultivariateNormal
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO, TraceTailAdaptive_ELBO, RenyiELBO
from tqdm import tqdm
data = dist.MultivariateNormal(loc=torch.tensor([1.0, 5.0, -3.0]), scale_tril=torch.tensor([[1.0, 0.0, 0.0],
[3.0, 1.0, 0.0],
[5.0, 10.0, 1.0]])).sample((1000,))
missing_lengths = torch.randint(0, 3, (1000,))
missing_mask = torch.arange(3)
missing_mask = missing_mask.unsqueeze(-2) <= missing_lengths.unsqueeze(-1)
data[~missing_mask] = 0
def model(data, missing_mask):
imp_loc = pyro.sample('imp_loc', dist.MultivariateNormal(torch.zeros(3), scale_tril=torch.eye(3)))
imp_scale = pyro.sample('imp_scale', dist.HalfCauchy(torch.tensor(1.)))
global_loc = pyro.sample('global_loc', dist.MultivariateNormal(torch.ones(3), torch.eye(3) * 0.1))
with pyro.plate('data', data.size(0), dim=-1):
imp = pyro.sample('imp', dist.MultivariateNormal(imp_loc, scale_tril=torch.eye(3) * imp_scale))
data_obs = data * missing_mask.float() + imp * (~missing_mask).float()
pyro.sample('obs', dist.MultivariateNormal(global_loc, scale_tril=torch.eye(3)), obs=data_obs)
guide = AutoMultivariateNormal(model)
optim = Adam(dict(lr=1e-3))
elbo = Trace_ELBO()
svi = SVI(model, guide, optim, elbo)
n_steps = 500
pg = tqdm(range(n_steps))
for ep in pg:
loss = svi.step(data, missing_mask)
pg.set_description(f"Epoch {ep}: {loss}")
If I change Trace_ELBO()
to TraceTailAdaptive_ELBO(num_particles=10, vectorize_particles=True)
in the above code for elbo
, I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
146 try:
--> 147 ret = self.fn(*args, **kwargs)
148 except (ValueError, RuntimeError):
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/messenger.py in _wraps(*args, **kwargs)
26 with self:
---> 27 return fn(*args, **kwargs)
28 _wraps.msngr = self
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in wrapped_fn(*args, **kwargs)
134 with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting):
--> 135 return fn(*args, **kwargs)
136
<ipython-input-4-c7fc2e8bb656> in model(data, missing_mask)
5 with pyro.plate('data', data.size(0), dim=-1):
----> 6 imp = pyro.sample('imp', dist.MultivariateNormal(imp_loc, scale_tril=torch.eye(3) * imp_scale))
7 data_obs = data * missing_mask.float() + imp * (~missing_mask).float()
RuntimeError: The size of tensor a (3) must match the size of tensor b (10) at non-singleton dimension 0
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
<ipython-input-11-e328309c426f> in <module>
2 pg = tqdm(range(n_steps))
3 for ep in pg:
----> 4 svi.step(data, missing_mask)
5 loss = loss_estim.loss(model, guide, data, missing_mask)
6 pg.set_description(f"Epoch {ep}: {loss}")
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
97 # get loss and compute gradients
98 with poutine.trace(param_only=True) as param_capture:
---> 99 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
100
101 params = set(site["value"].unconstrained()
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
123 loss = 0.0
124 # grab a trace from the generator
--> 125 for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
126 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
127 loss += loss_particle / self.num_particles
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, *args, **kwargs)
163 if self.max_plate_nesting == float('inf'):
164 self._guess_max_plate_nesting(model, guide, *args, **kwargs)
--> 165 yield self._get_vectorized_trace(model, guide, *args, **kwargs)
166 else:
167 for i in range(self.num_particles):
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in _get_vectorized_trace(self, model, guide, *args, **kwargs)
145 return self._get_trace(self._vectorized_num_particles(model),
146 self._vectorized_num_particles(guide),
--> 147 *args, **kwargs)
148
149 @abstractmethod
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, *args, **kwargs)
50 """
51 model_trace, guide_trace = get_importance_trace(
---> 52 "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
53 if is_validation_enabled():
54 check_if_enumerated(guide_trace)
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs)
42 guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
43 model_trace = poutine.trace(poutine.replay(model, trace=guide_trace),
---> 44 graph_type=graph_type).get_trace(*args, **kwargs)
45 if is_validation_enabled():
46 check_model_guide_match(model_trace, guide_trace, max_plate_nesting)
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
167 Calls this poutine and returns its trace instead of the function's return value.
168 """
--> 169 self(*args, **kwargs)
170 return self.msngr.get_trace()
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
151 six.reraise(exc_type,
152 exc_type(u"{}\n{}".format(exc_value, shapes)),
--> 153 traceback)
154 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
155 return ret
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/six.py in reraise(tp, value, tb)
690 value = tp()
691 if value.__traceback__ is not tb:
--> 692 raise value.with_traceback(tb)
693 raise value
694 finally:
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
145 args=args, kwargs=kwargs)
146 try:
--> 147 ret = self.fn(*args, **kwargs)
148 except (ValueError, RuntimeError):
149 exc_type, exc_value, traceback = sys.exc_info()
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/poutine/messenger.py in _wraps(*args, **kwargs)
25 def _wraps(*args, **kwargs):
26 with self:
---> 27 return fn(*args, **kwargs)
28 _wraps.msngr = self
29 return _wraps
/usr/local/miniconda3/envs/probprog-sandbox/lib/python3.7/site-packages/pyro/infer/elbo.py in wrapped_fn(*args, **kwargs)
133 return fn(*args, **kwargs)
134 with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting):
--> 135 return fn(*args, **kwargs)
136
137 return wrapped_fn
<ipython-input-4-c7fc2e8bb656> in model(data, missing_mask)
4 global_loc = pyro.sample('global_loc', dist.MultivariateNormal(torch.ones(3), torch.eye(3) * 0.1))
5 with pyro.plate('data', data.size(0), dim=-1):
----> 6 imp = pyro.sample('imp', dist.MultivariateNormal(imp_loc, scale_tril=torch.eye(3) * imp_scale))
7 data_obs = data * missing_mask.float() + imp * (~missing_mask).float()
8 pyro.sample('obs', dist.MultivariateNormal(global_loc, scale_tril=torch.eye(3)), obs=data_obs)
RuntimeError: The size of tensor a (3) must match the size of tensor b (10) at non-singleton dimension 0
Trace Shapes:
Param Sites:
Sample Sites:
num_particles_vectorized dist |
value 10 |
imp_loc dist 10 1 | 3
value 10 1 | 3
imp_scale dist 10 1 |
value 10 1 |
global_loc dist 10 1 | 3
value 10 1 | 3
data dist |
value 1000 |