pyro.plate on multiple GPUs

I implemented a model with a plate statement, and I want to train it with SVI on GPUs. For example, here is a simplified version of the model:

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
    def forward(self, obs):
        # a plate of 10 conditionally independent observations
        with pyro.plate('B', size=10):
            x = pyro.sample('rv', dist.Normal(torch.zeros(10), torch.ones(10)), obs=obs)
        return x

cuda = torch.device("cuda")
model = Model()
model = nn.DataParallel(model, device_ids=[0, 1])


ValueError: Caught ValueError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/torch/nn/parallel/", line 60, in _worker
    output = module(*input, **kwargs)
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/torch/nn/modules/", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "<ipython-input-3-0751ba414b33>", line 6, in forward
    with pyro.plate('B',size=10):
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/", line 18, in __enter__
    super(PlateMessenger, self).__enter__()
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/", line 83, in __enter__
    self.dim = _DIM_ALLOCATOR.allocate(, self.dim)
  File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/", line 32, in allocate
    raise ValueError('duplicate plate "{}"'.format(name))
ValueError: duplicate plate "B"

I can't run the forward pass or compute the ELBO. How can I do this on multiple GPUs?

Thanks!

Thanks for the simple example. This is probably a bug, since we don't do much testing of multi-GPU functionality. Can you open a GitHub issue with this and any other DataParallel-related failures you're seeing? In the meantime, is it possible in your case to work around it by performing only deterministic computations inside DataParallel modules? For example:

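import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

class Model(nn.Module):
    def __init__(self):
        # super().__init__() must run before any submodule is assigned
        super(Model, self).__init__()
        # here ExpensiveNN is entirely deterministic (ExpensiveNN and "..." are placeholders)
        self.expensive_nn = nn.DataParallel(ExpensiveNN(...), ..., dim=-1)
    def forward(self, obs):
        with pyro.plate('B', size=10, dim=-1):
            # only deterministic work is split across GPUs; sampling stays outside DataParallel
            likelihood_params = self.expensive_nn(...)
            x = pyro.sample('rv', dist.Normal(*likelihood_params), obs=obs)
        return x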