I implemented a model that uses a `pyro.plate` statement, and I want to train it with SVI on multiple GPUs. For example, here is a simplified version of the model:
class Model(nn.Module):
    """Minimal Pyro model: a single Normal sample site inside one plate."""

    def __init__(self):
        super().__init__()

    def forward(self, obs):
        """Sample (or condition on) site ``'rv'`` inside plate ``'B'``.

        Args:
            obs: Forwarded unchanged as the ``obs`` argument of
                ``pyro.sample``.

        Returns:
            Whatever ``pyro.sample`` returns for site ``'rv'``.
        """
        # Parameters of the Normal distribution: zero mean, unit scale,
        # one entry per element of the size-10 plate.
        loc = torch.zeros(10)
        scale = torch.ones(10)
        with pyro.plate('B', size=10):
            sample = pyro.sample('rv', dist.Normal(loc, scale), obs=obs)
        return sample
# Device handle the wrapped model is moved to.
cuda = torch.device("cuda")

# Build the model and wrap it in DataParallel over devices 0 and 1.
model = nn.DataParallel(Model(), device_ids=[0, 1])
model.to(cuda)

# Forward pass with no observation; this is the call that produces the
# "duplicate plate" ValueError shown in the traceback that follows.
model(None)
ValueError: Caught ValueError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
output = module(*input, **kwargs)
File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/torch/nn/modules/module.py", line 532, in __call__
result = self.forward(*input, **kwargs)
File "<ipython-input-3-0751ba414b33>", line 6, in forward
with pyro.plate('B',size=10):
File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py", line 18, in __enter__
super(PlateMessenger, self).__enter__()
File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/indep_messenger.py", line 83, in __enter__
self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
File "/home/odie/.conda/envs/pyro/lib/python3.7/site-packages/pyro/poutine/runtime.py", line 32, in allocate
raise ValueError('duplicate plate "{}"'.format(name))
ValueError: duplicate plate "B"
As a result, I can’t run the forward pass or compute the ELBO. How can I do this on multiple GPUs?
Thanks!