Hi all,
I have a problem with a variable inside a minibatch (subsampled) plate:
pyro.plate("regions_plate", size=self.n_regions, dim=-3, subsample=ind)
with regions_plate as ind:
x = pyro.sample(
"x",
dist.Gamma(torch.ones([1, 1, 1]), torch.ones([1, 1, 1]))
)
Forward sampling works as expected (sample shapes are correct); however, the AutoNormal/AutoHierarchicalNormalMessenger guides fail to create parameters for this variable:
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs)
146 msg = {
147 "type": "sample",
148 "name": name,
(...)
160 "continuation": None,
161 }
162 # apply the stack and return its return value
--> 163 apply_stack(msg)
164 return msg["value"]
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/poutine/runtime.py:213, in apply_stack(initial_msg)
209 for frame in reversed(stack):
211 pointer = pointer + 1
--> 213 frame._process_message(msg)
215 if msg["stop"]:
216 break
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/poutine/messenger.py:154, in Messenger._process_message(self, msg)
152 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
153 if method is not None:
--> 154 return method(msg)
155 return None
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/poutine/guide.py:62, in GuideMessenger._pyro_sample(self, msg)
60 prior = msg["fn"]
61 msg["infer"]["prior"] = prior
---> 62 posterior = self.get_posterior(msg["name"], prior)
63 if isinstance(posterior, torch.Tensor):
64 posterior = dist.Delta(posterior, event_dim=prior.event_dim)
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:277, in AutoHierarchicalNormalMessenger.get_posterior(self, name, prior)
274 transform = biject_to(prior.support)
275 if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
276 # If hierarchical_sites not specified all sites are assumed to be hierarchical
--> 277 loc, scale, weight = self._get_params(name, prior)
278 loc = loc + transform.inv(prior.mean) * weight
279 posterior = dist.TransformedDistribution(
280 dist.Normal(loc, scale).to_event(transform.domain.event_dim),
281 transform.with_cache(),
282 )
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:306, in AutoHierarchicalNormalMessenger._get_params(self, name, prior)
304 constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
305 unconstrained = transform.inv(constrained)
--> 306 init_loc = self._adjust_plates(unconstrained, event_dim)
307 init_scale = torch.full_like(init_loc, self._init_scale)
308 if self.weight_type == "scalar":
309 # weight is a single value parameter
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:77, in AutoMessenger._adjust_plates(self, value, event_dim)
75 value = value.mean(dim, keepdim=True)
76 elif f.size != full_size:
---> 77 value = periodic_repeat(value, full_size, dim).contiguous()
78 for dim in range(value.dim() - event_dim):
79 value = value.squeeze(0)
File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/ops/tensor_utils.py:79, in periodic_repeat(tensor, size, dim)
59 def periodic_repeat(tensor, size, dim):
60 """
61 Repeat a ``period``-sized tensor up to given ``size``. For example::
62
(...)
77 :param int dim: The tensor dimension along which to repeat.
78 """
---> 79 assert isinstance(size, int) and size >= 0
80 assert isinstance(dim, int)
81 if dim >= 0:
AssertionError:
The same model also has a second minibatch plate at dim=-4 and a third plate at dim=-2 for certain global variables. Variables that use those other two plates are created without issues. The problem occurs for any variable in `regions_plate`.
Any help with this would be greatly appreciated.