Hi,
I am having some trouble with setting data batch sizes and the number of ELBO particles.
If I have a model with two nested plates, and another plate outside those:
def fn1():
with pyro.plate('D', 3, dim=-1):
with pyro.plate('K', 2, dim=-2):
cov_factor = pyro.sample('cov_factor', dist.Normal(0.,1.))
cov_factor = cov_factor.transpose(-2,-1)
with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
print(cov_factor.shape)
obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
print(obs.shape)
return obs
and set the batch size and n_particles to the same number
pyro.clear_param_store()
fn = fn1
batch_size = 10
n_particles = 10
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))
it samples 10 cov_factors, and 10 times 10 batches of observations, as expected (there’s a singleton dimension, but I guess it can just be squeezed afterwards):
torch.Size([10, 3, 2])
torch.Size([10, 1, 10, 3])
tensor(340.5119, grad_fn=)
If I try having different batch size and n_particles,
pyro.clear_param_store()
fn = fn1
batch_size = 10
n_particles = 16
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))
it samples the cov_factor,
torch.Size([16, 3, 2])
but instead of getting a torch.Size([16, 1, 10, 3]) obs tensor, I get a shape mismatch error:
~/anaconda3/lib/python3.7/site-packages/pyro/infer/elbo.py in wrapped_fn(*args, **kwargs)
136 with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting):
--> 137 return fn(*args, **kwargs)
138
<ipython-input-16-a841b8ba304c> in fn1()
7 print(cov_factor.shape)
----> 8 obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
9 print(obs.shape)
~/anaconda3/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
~/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
192
--> 193 frame._process_message(msg)
194
~/anaconda3/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
14 super(PlateMessenger, self)._process_message(msg)
---> 15 return BroadcastMessenger._pyro_sample(msg)
16
~/anaconda3/lib/python3.7/contextlib.py in inner(*args, **kwds)
73 with self._recreate_cm():
---> 74 return func(*args, **kwds)
75 return inner
~/anaconda3/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
58 raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 59 f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
60 target_batch_shape[f.dim] = f.size
ValueError: Shape mismatch inside plate('N') at site obs dim -1, 10 vs 16
If I try a model without nested plates, both settings go well:
def fn2():
with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
cov_factor = torch.ones(3,2)
print(cov_factor.shape)
obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
print(obs.shape)
return obs
pyro.clear_param_store()
fn = fn2
batch_size = 10
n_particles = 10
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))
I get n_particles batches of size 10 each:
torch.Size([10, 10, 3]) /
torch.Size([16, 10, 3])
If I have a model without nested plates, but use expand and to_event instead,
def fn3():
cov_factor = pyro.sample('cov_factor', dist.Normal(0.,1.).expand([3,2]).to_event(0).to_event(1))
with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
print(cov_factor.shape)
obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
print(obs.shape)
return obs
pyro.clear_param_store()
fn = fn3
batch_size = 10
n_particles = 10
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))
I get 10 cov_factors, and 10 x 10 batches, no singleton dimension.
torch.Size([10, 3, 2])
torch.Size([10, 10, 3])
but different batch size and n_particles fails like before with
ValueError: Shape mismatch inside plate('N') at site obs dim -1, 10 vs 16
Here I would expect it to just give me a torch.Size([16, 10, 3]).
What is going wrong?