Minibatch plate AssertionError in `init_loc = self._adjust_plates(unconstrained, event_dim)`

Hi all

I have a problem with a variable in a minibatch (subsampled) plate:

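import torch
import pyro
import pyro.distributions as dist

# ind: minibatch indices of the regions, passed in from outside
regions_plate = pyro.plate("regions_plate", size=self.n_regions, dim=-3, subsample=ind)
with regions_plate as ind:
    x = pyro.sample(
        "x",
        dist.Gamma(torch.ones([1, 1, 1]), torch.ones([1, 1, 1])),
    )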

Forward sampling works as expected (correct sample shapes), but the AutoNormal/AutoHierarchicalNormalMessenger guides fail to create parameters for this variable:

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs)
    146 msg = {
    147     "type": "sample",
    148     "name": name,
   (...)
    160     "continuation": None,
    161 }
    162 # apply the stack and return its return value
--> 163 apply_stack(msg)
    164 return msg["value"]

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/poutine/runtime.py:213, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    211     pointer = pointer + 1
--> 213     frame._process_message(msg)
    215     if msg["stop"]:
    216         break

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/poutine/messenger.py:154, in Messenger._process_message(self, msg)
    152 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    153 if method is not None:
--> 154     return method(msg)
    155 return None

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/poutine/guide.py:62, in GuideMessenger._pyro_sample(self, msg)
     60 prior = msg["fn"]
     61 msg["infer"]["prior"] = prior
---> 62 posterior = self.get_posterior(msg["name"], prior)
     63 if isinstance(posterior, torch.Tensor):
     64     posterior = dist.Delta(posterior, event_dim=prior.event_dim)

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:277, in AutoHierarchicalNormalMessenger.get_posterior(self, name, prior)
    274     transform = biject_to(prior.support)
    275 if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
    276     # If hierarchical_sites not specified all sites are assumed to be hierarchical
--> 277     loc, scale, weight = self._get_params(name, prior)
    278     loc = loc + transform.inv(prior.mean) * weight
    279     posterior = dist.TransformedDistribution(
    280         dist.Normal(loc, scale).to_event(transform.domain.event_dim),
    281         transform.with_cache(),
    282     )

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:306, in AutoHierarchicalNormalMessenger._get_params(self, name, prior)
    304 constrained = self.init_loc_fn({"name": name, "fn": prior}).detach()
    305 unconstrained = transform.inv(constrained)
--> 306 init_loc = self._adjust_plates(unconstrained, event_dim)
    307 init_scale = torch.full_like(init_loc, self._init_scale)
    308 if self.weight_type == "scalar":
    309     # weight is a single value parameter

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:77, in AutoMessenger._adjust_plates(self, value, event_dim)
     75             value = value.mean(dim, keepdim=True)
     76     elif f.size != full_size:
---> 77         value = periodic_repeat(value, full_size, dim).contiguous()
     78 for dim in range(value.dim() - event_dim):
     79     value = value.squeeze(0)

File /miniconda3farm5/envs/test_scvi16_cuda111/lib/python3.9/site-packages/pyro/ops/tensor_utils.py:79, in periodic_repeat(tensor, size, dim)
     59 def periodic_repeat(tensor, size, dim):
     60     """
     61     Repeat a ``period``-sized tensor up to given ``size``. For example::
     62 
   (...)
     77     :param int dim: The tensor dimension along which to repeat.
     78     """
---> 79     assert isinstance(size, int) and size >= 0
     80     assert isinstance(dim, int)
     81     if dim >= 0:

AssertionError: 
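For context on the failing frame: periodic_repeat tiles a tensor computed for the current (subsampled) plate slice up to the plate's full size. The sketch below is a hypothetical, list-based illustration of that idea, not Pyro's implementation; it only keeps the same integer check on size that fails in the traceback.

```python
def periodic_repeat_list(values, size):
    """Hypothetical list-based illustration of periodic repetition (not Pyro's code)."""
    # Same kind of check as the one that raised the AssertionError above.
    assert isinstance(size, int) and size >= 0
    # A non-empty period is needed unless nothing is requested.
    assert values or size == 0
    period = len(values)
    # Cycle through the period until `size` elements have been produced.
    return [values[i % period] for i in range(size)]


print(periodic_repeat_list([1, 2, 3], 7))  # [1, 2, 3, 1, 2, 3, 1]
```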

The same model also has a second minibatch plate in dim=-4 and a third plate in dim=-2 for certain global variables. Parameters for variables in those two other plates are created without issues. The problem occurs with any variable in regions_plate.

Would be great to get any help with this.

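It turns out this assertion failed because size=self.n_regions was not of type int. Surprisingly, the same code works fine when this plate is not used for minibatching. That is because _adjust_plates only calls periodic_repeat (and so only reaches the isinstance(size, int) check) when the plate's current size differs from its full size (the `elif f.size != full_size:` branch in the traceback), which happens only when the plate is subsampled.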
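To illustrate how a value that looks like an integer can still fail this check, here is a small sketch that assumes, hypothetically, that n_regions was computed with NumPy (the thread does not say which type it actually was). NumPy integer scalars are not subclasses of Python's built-in int, so isinstance(..., int) returns False for them; converting with int() before building the plate avoids the assertion.

```python
import numpy as np

# Hypothetical: the number of regions derived from integer region labels.
region_labels = np.array([0, 2, 1, 2, 3])
n_regions = region_labels.max() + 1  # a NumPy integer scalar, not a Python int

# This is the kind of check periodic_repeat performs on `size`.
print(isinstance(n_regions, int))  # False

# Convert to a built-in int before passing it to pyro.plate(size=...).
n_regions = int(n_regions)
print(isinstance(n_regions, int))  # True
```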

Hi @vitkl, so have you resolved the problem by making self.n_regions an int? If you still need help, would you mind posting more model details?

Hi @fritzo,

Thanks for getting back to me. Making self.n_regions an int fixes the problem, but it took 8+ hours to find, so it would be great if this type check were done earlier, specifically when the plate is created.
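As a sketch of the kind of early check requested here, the helper below validates a plate size once, at construction time. validate_plate_size is a hypothetical name and is not part of Pyro's API; it uses operator.index, which accepts Python ints and integer-like objects implementing __index__ (such as NumPy integer scalars) but rejects floats.

```python
import operator


def validate_plate_size(name, size):
    """Hypothetical eager check (not part of Pyro): return `size` as an integer or fail fast."""
    try:
        # Accepts int and objects implementing __index__; rejects floats and strings.
        size = operator.index(size)
    except TypeError:
        raise TypeError(
            f"plate {name!r}: size must be an integer, got {type(size).__name__}"
        ) from None
    if size < 0:
        raise ValueError(f"plate {name!r}: size must be non-negative, got {size}")
    return size


print(validate_plate_size("regions_plate", 5))  # 5
```

With such a check, the plate from the question would be built as pyro.plate("regions_plate", size=validate_plate_size("regions_plate", self.n_regions), dim=-3, subsample=ind), so a bad size fails immediately instead of deep inside the guide. operator.index is used rather than int() because int() would silently truncate a float such as 5.7 and hide the underlying bug.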
