Batch size vs. number of ELBO particles confusion

Hi,

I am having some trouble with setting data batch sizes and the number of ELBO particles.
If I have a model with two nested plates, followed by a separate plate for the data:

def fn1():
    with pyro.plate('D', 3, dim=-1):
        with pyro.plate('K', 2, dim=-2):
            cov_factor = pyro.sample('cov_factor', dist.Normal(0.,1.))
        cov_factor = cov_factor.transpose(-2,-1)
    with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
        print(cov_factor.shape)
        obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
        print(obs.shape)
        return obs

and set batch_size and n_particles to the same value:

pyro.clear_param_store()
fn = fn1
batch_size = 10
n_particles = 10
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))

it samples 10 cov_factors and 10 batches of 10 observations each, as expected (there is an extra singleton dimension, but I guess it can just be squeezed out afterwards):
torch.Size([10, 3, 2])
torch.Size([10, 1, 10, 3])
tensor(340.5119, grad_fn=)

If I instead use different values for batch_size and n_particles,

pyro.clear_param_store()
fn = fn1
batch_size = 10
n_particles = 16
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))

it samples the cov_factor,
torch.Size([16, 3, 2])
but instead of getting a torch.Size([16, 1, 10, 3]) obs tensor, I get a shape mismatch error:

~/anaconda3/lib/python3.7/site-packages/pyro/infer/elbo.py in wrapped_fn(*args, **kwargs)
    136             with pyro.plate("num_particles_vectorized", self.num_particles, dim=-self.max_plate_nesting):
--> 137                 return fn(*args, **kwargs)
    138 

<ipython-input-16-a841b8ba304c> in fn1()
      7         print(cov_factor.shape)
----> 8         obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
      9         print(obs.shape)

~/anaconda3/lib/python3.7/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
    112         # apply the stack and return its return value
--> 113         apply_stack(msg)
    114         return msg["value"]

~/anaconda3/lib/python3.7/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
    192 
--> 193         frame._process_message(msg)
    194 

~/anaconda3/lib/python3.7/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
     14         super(PlateMessenger, self)._process_message(msg)
---> 15         return BroadcastMessenger._pyro_sample(msg)
     16 

~/anaconda3/lib/python3.7/contextlib.py in inner(*args, **kwds)
     73             with self._recreate_cm():
---> 74                 return func(*args, **kwds)
     75         return inner

~/anaconda3/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
     58                     raise ValueError("Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
---> 59                         f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
     60                 target_batch_shape[f.dim] = f.size

ValueError: Shape mismatch inside plate('N') at site obs dim -1, 10 vs 16

With a model without nested plates, both settings work:

def fn2():
    with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
        cov_factor = torch.ones(3,2)
        print(cov_factor.shape)
        obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
        print(obs.shape)
        return obs

pyro.clear_param_store()
fn = fn2
batch_size = 10
n_particles = 10
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))

I get n_particles batches of size 10 each (for n_particles = 10 and n_particles = 16, respectively):
torch.Size([10, 10, 3])
torch.Size([16, 10, 3])

If I have a model without nested plates, but use expand and to_event instead,

def fn3():
    cov_factor = pyro.sample('cov_factor', dist.Normal(0.,1.).expand([3,2]).to_event(0).to_event(1))
    with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
        print(cov_factor.shape)
        obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
        print(obs.shape)
        return obs

pyro.clear_param_store()
fn = fn3
batch_size = 10
n_particles = 10
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))

I get 10 cov_factors and 10 batches of 10 observations, with no singleton dimension:
torch.Size([10, 3, 2])
torch.Size([10, 10, 3])

but using different values for batch_size and n_particles fails as before with

ValueError: Shape mismatch inside plate('N') at site obs dim -1, 10 vs 16

Here I would expect it to just give me a torch.Size([16, 10, 3]).

What is going wrong?

Hi @deoxy, nice example!

.to_event() accepts the number of rightmost dims to convert to event dims, whereas it looks like you are trying to pass the dims themselves. Try replacing .to_event(0).to_event(1) with

.to_event(2)
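A small pure-Python sketch of the shape bookkeeping may help here. It assumes only the documented behaviour of `to_event(n)` (the n rightmost batch dims are reinterpreted as event dims); `to_event_shapes` is a hypothetical helper that works on shape tuples and does not call Pyro.

```python
def to_event_shapes(batch_shape, event_shape, n):
    """Hypothetical helper: mimic d.to_event(n) on shapes only.

    The n rightmost batch dims are moved into the event shape.
    """
    batch_shape, event_shape = tuple(batch_shape), tuple(event_shape)
    if n > len(batch_shape):
        raise ValueError("cannot move %d dims; batch_shape is %s" % (n, batch_shape))
    split = len(batch_shape) - n
    return batch_shape[:split], batch_shape[split:] + event_shape


# dist.Normal(0., 1.).expand([3, 2]) has batch_shape (3, 2) and event_shape ().
expanded = ((3, 2), ())

# Original code: .to_event(0) moves no dims, then .to_event(1) moves one.
b, e = to_event_shapes(*expanded, 0)
b, e = to_event_shapes(b, e, 1)
print("to_event(0).to_event(1):", b, e)  # batch (3,), event (2,)

# Suggested fix: .to_event(2) moves both dims into the event shape.
print("to_event(2):", to_event_shapes(*expanded, 2))  # batch (), event (3, 2)
```

Under vectorized particles this matters: with `.to_event(2)` the cov_factor site has no batch dims of its own, so the particle dim added by the ELBO lands to the left of the 'N' plate's dim -1 instead of becoming the rightmost batch dim of the LowRankMultivariateNormal.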

Thanks @fritzo, that works! Just for completeness’ sake, is there a way to get the first model, with plates, to work as well? If I understand right, using to_event() is not enough to mark conditional independence.

I am a new user and don’t know the exact answer to your question. However, when I encountered the same shape mismatch problems, it helped me a lot to print batch_shape and event_shape of the distributions without doing any sampling to see if the tensor shapes are as expected. In your case, I would suggest printing out dist.LowRankMultivariateNormal(…).batch_shape and event_shape before you sample obs, maybe it would help you understand the tensor arithmetic a bit better.
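Following that suggestion, the shapes can be worked out without sampling at all. The pure-Python sketch below assumes the shape rule of `torch.distributions.LowRankMultivariateNormal`: batch_shape is the broadcast of `loc.shape[:-1]`, `cov_factor.shape[:-2]` and `cov_diag.shape[:-1]`, and event_shape is `loc.shape[-1:]`. Both helpers are hypothetical names used only for illustration.

```python
def broadcast_shapes(*shapes):
    """Right-aligned broadcasting of shape tuples (NumPy/PyTorch rules)."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for i in range(-ndim, 0):
        # Sizes other than 1 at this position must all agree.
        sizes = {s[i] for s in shapes if len(s) >= -i and s[i] != 1}
        if len(sizes) > 1:
            raise ValueError("cannot broadcast shapes %s" % (shapes,))
        result.append(sizes.pop() if sizes else 1)
    return tuple(result)


def low_rank_mvn_shapes(loc_shape, cov_factor_shape, cov_diag_shape):
    """Hypothetical helper: (batch_shape, event_shape) of a LowRankMultivariateNormal."""
    batch_shape = broadcast_shapes(
        tuple(loc_shape[:-1]), tuple(cov_factor_shape[:-2]), tuple(cov_diag_shape[:-1])
    )
    return batch_shape, tuple(loc_shape[-1:])


# loc = torch.zeros(3) and cov_diag = torch.ones(3), as in fn1.
# cov_factor shapes: no particles, 10 particles, 16 particles.
for cov_factor_shape in [(3, 2), (10, 3, 2), (16, 3, 2)]:
    print(cov_factor_shape, "->", low_rank_mvn_shapes((3,), cov_factor_shape, (3,)))
```

In fn1 the vectorized particle dim is the leading dim of cov_factor (torch.Size([16, 3, 2]) above), so it becomes the rightmost batch dim of the distribution, i.e. dim -1, which is also the dim of the 'N' plate. With 10 particles the two sizes coincide by accident; with 16 they clash, which is the `10 vs 16` error.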

That’s good advice! I more or less solved the problem by swapping the dims of the 'D' and 'K' plates and unsqueezing cov_factor:

def fn1():
    with pyro.plate('D', 3, dim=-2):
        with pyro.plate('K', 2, dim=-1):
            cov_factor = pyro.sample('cov_factor', dist.Normal(0.,1.))
        if cov_factor.dim() == 3:
            cov_factor = cov_factor.unsqueeze(-3)
    with pyro.plate('N', 100, subsample_size=batch_size, dim=-1):
        print(cov_factor.shape)
        obs = pyro.sample('obs', dist.LowRankMultivariateNormal(torch.zeros(3),cov_factor=cov_factor,cov_diag=torch.ones(3)))
        print(obs.shape)
        return obs

pyro.clear_param_store()
fn = fn1
batch_size = 10
n_particles = 16
elbo = Trace_ELBO(num_particles=n_particles, vectorize_particles=True)
elbo.differentiable_loss(fn, AutoDiagonalNormal(fn))

which gives me a superfluous singleton dimension

torch.Size([16, 1, 3, 2])
torch.Size([16, 1, 10, 3])

but avoids the shape mismatch error.

In trying to understand how Pyro vectorizes particles, I came upon pyro-ppl/pyro issue #791 (“Parallelize ELBO computation over num_particles”) and https://github.com/pyro-ppl/pyro/pull/1176, where the devs discuss how to avoid double-broadcasting. Double-broadcasting does happen if I wrap my model in a plate myself: obs then has shape torch.Size([16, 10, 16, 3]) or similar. But reading elbo.py on the dev branch of pyro-ppl/pyro, vectorization also seems to be just a plate (line 136). So why is there no double-broadcasting in the example above?
If anyone can explain, I’d appreciate it 🙂

You’re right, vectorization of particles is accomplished by a plate. Note that this plate needs dim=-3 to stay clear of all the other plates in your fn1() model. I’m not sure why wrapping your model in a plate results in double-broadcasting and a shape of torch.Size([16, 10, 16, 3]). Are you sure your plate sets dim=-3?
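Here is a rough, simplified re-implementation of the check that raises that error, modelled on the message format in the traceback from broadcast_messenger.py. It is an illustration under that assumption, not Pyro's actual code, and `plate_broadcast` is a hypothetical name. Plates are processed from outermost (the `num_particles_vectorized` plate, at dim=-3 for fn1) to innermost, and each one claims its dim.

```python
def plate_broadcast(site, batch_shape, plates):
    """Hypothetical, simplified sketch of plate broadcasting at one sample site.

    batch_shape: the distribution's own batch shape (a tuple).
    plates: (name, size, dim) triples ordered from outermost to innermost.
    Returns the batch shape after expanding over all plates, or raises ValueError.
    """
    # Singleton dims are free to be expanded, so mark them as unclaimed (None).
    target = [None if size == 1 else size for size in batch_shape]
    for name, size, dim in plates:
        # Left-pad with unclaimed dims so that index `dim` exists.
        target = [None] * max(0, -dim - len(target)) + target
        if target[dim] is not None and target[dim] != size:
            raise ValueError(
                "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
                    name, site, dim, size, target[dim]))
        target[dim] = size
    # Dims not claimed by any plate stay singleton.
    return tuple(1 if s is None else s for s in target)


N_PLATE = ("N", 10, -1)  # pyro.plate('N', 100, subsample_size=10, dim=-1)
for n_particles in (10, 16):
    particle_plate = ("num_particles_vectorized", n_particles, -3)
    # In fn1 the distribution at 'obs' has batch_shape (n_particles,),
    # inherited from the leading dim of cov_factor.
    try:
        shape = plate_broadcast("obs", (n_particles,), [particle_plate, N_PLATE])
        print(n_particles, "particles ->", shape)
    except ValueError as err:
        print(n_particles, "particles ->", err)
```

In this sketch the 16-particle case reproduces the message from the traceback, and the 10-particle case only passes because the particle dim of cov_factor happens to have the same size as the 'N' batch at dim -1.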

I think the issue is at the line that samples cov_factor. cov_factor should have shape batch_shape + mvn.event_shape + (rank,). Your plate statements take care of batch_shape, but it seems you did not declare the mvn.event_shape + (rank,) part. One way to do that is

cov_factor = pyro.sample('cov_factor', dist.Normal(0.,1.).expand([3, 1]).to_event(2))

assuming that mvn.event_shape == (3,) and rank = 1.
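To see why the last two dims of cov_factor have to be (event size, rank): LowRankMultivariateNormal uses the covariance `cov_factor @ cov_factor.T + diag(cov_diag)`, so for a 3-dimensional event each draw of cov_factor must be a 3 x rank matrix. Below is a pure-Python sketch of that formula for the rank-1 case above; the factor values are made up and `low_rank_covariance` is a hypothetical helper.

```python
def low_rank_covariance(cov_factor, cov_diag):
    """Hypothetical helper: W @ W.T + diag(D) for a (event_size, rank) factor W."""
    n = len(cov_factor)
    rank = len(cov_factor[0])
    if len(cov_diag) != n or any(len(row) != rank for row in cov_factor):
        raise ValueError("cov_factor must be (event_size, rank), cov_diag (event_size,)")
    # Low-rank part: entry (i, j) is the dot product of rows i and j of W.
    cov = [[sum(cov_factor[i][r] * cov_factor[j][r] for r in range(rank))
            for j in range(n)] for i in range(n)]
    # Diagonal part.
    for i in range(n):
        cov[i][i] += cov_diag[i]
    return cov


# Rank 1, event size 3: one draw of cov_factor has shape (3, 1), as with expand([3, 1]).
W = [[1.0], [2.0], [0.5]]
for row in low_rank_covariance(W, [1.0, 1.0, 1.0]):
    print(row)
```

Any dims to the left of these two are batch dims, which is where the plates (and the vectorized particle dim) go.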

Ah, I see. I thought that for a fully-factorized prior on cov_factor I needed to declare 2 independent dimensions with plates (so that Trace_ELBO could exploit the fully-factorized structure), but cov_factor is still a single draw from one distribution, so plates don’t make sense there. I’ve been doing this wrong in several models 😅

I guess it doesn’t matter now, but the double-broadcasting thing only happened when I set the dim on the ‘obs’ plate to -2 (to avoid the shape mismatch). So cov_factor had batch_shape (16,3,2), and with the whole model wrapped in a dim=-3 plate, the obs shape was (16,10,16,3).

Thanks for the help everyone.

@deoxy, a high-level point worth keeping in mind is that for SVI, Pyro primarily uses plate information in two ways:

  • to support subsampling
  • to compute lower-variance score-function gradients when required

However, if all your distributions are reparameterizable, the second point is moot, so depending on the details it is often unnecessary to inform Pyro exhaustively of every last conditional independence (it often won’t matter for SVI training).
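To illustrate the subsampling point: for a subsampled plate such as `pyro.plate('N', 100, subsample_size=batch_size)`, Pyro scales the minibatch log-likelihood by size / subsample_size (here 100 / 10 = 10), which makes it an unbiased estimate of the full-data term. The sketch below checks this numerically in pure Python; the log-probabilities are made-up random numbers, `scaled_minibatch_log_lik` is a hypothetical helper, and uniform sampling without replacement is assumed for the minibatch indices.

```python
import random


def scaled_minibatch_log_lik(log_probs, subsample_size, rng):
    """Hypothetical helper: sum a random minibatch, rescaled by size / subsample_size."""
    size = len(log_probs)
    # Uniform minibatch without replacement (an assumption for this sketch).
    idx = rng.sample(range(size), subsample_size)
    scale = size / subsample_size
    return scale * sum(log_probs[i] for i in idx)


rng = random.Random(0)
# Made-up per-datapoint log-probabilities for 100 observations.
log_probs = [rng.gauss(-3.0, 1.0) for _ in range(100)]
full = sum(log_probs)
estimates = [scaled_minibatch_log_lik(log_probs, 10, rng) for _ in range(20000)]
print("full-data log-lik:", round(full, 2))
print("mean of minibatch estimates:", round(sum(estimates) / len(estimates), 2))
```

The two printed numbers should be close; individual estimates are noisy, which is the usual price of subsampling.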
