SVI: Normal in model, MultivariateNormal in guide

In my model I have a plate over the mini-batch with dim=-1 and a plate over the components of a Normal with dim=-2:

import torch
import pyro
import pyro.distributions as dist

def model():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        # sample other latents using mini_batch dim=-1
        with pyro.plate('components', n_components, dim=-2):
            # loc and scale have shape [n_components, size_mini_batch]
            a = pyro.sample('a', dist.Normal(loc, scale), obs=obs)
    return a

The model produces samples of a with shape [n_components, size_mini_batch].

Then in the guide I have:

def guide():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        lam = net(mini_batch)  # mini_batch on dim 0; lam shape [size_mini_batch, 2*n_components]
        lam_reshape = lam.transpose(0, -1)  # switch mini_batch dim to -1; lam_reshape shape [2*n_components, size_mini_batch]
        loc, log_scale = lam_reshape[:n_components], lam_reshape[n_components:2 * n_components]
        scale = torch.exp(log_scale)

        with pyro.plate('components', n_components, dim=-2):
            # no obs= in the guide: "a" is sampled here as a latent site
            a = pyro.sample("a", dist.Normal(loc, scale))
    return a

I’d like to extend the Normal distribution in the guide to a MultivariateNormal, where the covariance matrix is no longer diagonal.

Usually in pure PyTorch I’d put the batch dimension at dim 0 and sample like so:

MultivariateNormal(mv_loc, mv_cov), with mv_loc.shape of [size_mini_batch, n_components] and mv_cov.shape of [size_mini_batch, n_components, n_components]. But the shape of the sample ([size_mini_batch, n_components]) doesn’t match what’s in the model ([n_components, size_mini_batch]).
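For reference, the shape mismatch described above can be reproduced in plain PyTorch. This is a minimal sketch with hypothetical sizes (size_mini_batch = 4, n_components = 3) and a random positive-definite covariance built as A @ A^T + I:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
size_mini_batch, n_components = 4, 3  # hypothetical sizes for illustration

mv_loc = torch.zeros(size_mini_batch, n_components)
# A batch of positive-definite covariances: A @ A^T is PSD, adding I makes it PD.
A = torch.randn(size_mini_batch, n_components, n_components)
mv_cov = A @ A.transpose(-1, -2) + torch.eye(n_components)

mvn = MultivariateNormal(mv_loc, covariance_matrix=mv_cov)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([4]) torch.Size([3])

sample = mvn.sample()
print(sample.shape)    # torch.Size([4, 3]): [size_mini_batch, n_components]
print(sample.T.shape)  # torch.Size([3, 4]): the layout the model uses
```

In PyTorch (and Pyro) event dimensions are always the rightmost dimensions, so a MultivariateNormal over components cannot put its batch dimension to the right of its event dimension; the sample has to be transposed after the fact.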

What are my options here?

  • What shape should mv_loc and mv_cov have so that it parallelizes over dim=-1? Is this possible?
  • I’m fairly locked into the mini_batch dim=-1 with how my model code is vectorized (lots not shown here). However, was this a bad design choice? Can I just step out of the mini_batch plate for sampling a, use dim=0 for it, and keep everything else the same?

Hi @geoffwoollard,
The easiest option might be to change your model to drop the components plate and instead declare the component dimension as an event dimension with .to_event(1). Modifying your current model, you could write:

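def model_v2():
    # shapes as in your original model
    assert loc.shape == (n_components, size_mini_batch)
    assert scale.shape == (n_components, size_mini_batch)

    # transpose to [size_mini_batch, n_components]; .to_event(1) makes components
    # an event dim, leaving mini_batch as the rightmost batch dim (dim=-1)
    obs_T = None if obs is None else obs.T
    with pyro.plate("mini_batch", size_mini_batch, dim=-1):
        a = pyro.sample("a", dist.Normal(loc.T, scale.T).to_event(1), obs=obs_T).T
    return a  # back to [n_components, size_mini_batch]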
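To see why .to_event(1) lets you drop the components plate, here is a shape check in plain PyTorch. Pyro’s .to_event(1) wraps a distribution in an Independent with one reinterpreted batch dim; this sketch uses torch.distributions.Independent directly so it runs without Pyro, with hypothetical sizes:

```python
import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
size_mini_batch, n_components = 4, 3  # hypothetical sizes

loc = torch.randn(n_components, size_mini_batch)  # layout of the original model
scale = torch.rand(n_components, size_mini_batch) + 0.1

# Normal on the transposed parameters, with the rightmost (component) dim
# reinterpreted as an event dim, as .to_event(1) does in Pyro.
d = Independent(Normal(loc.T, scale.T), 1)
print(d.batch_shape, d.event_shape)  # torch.Size([4]) torch.Size([3])

x = d.sample()         # [size_mini_batch, n_components]
log_p = d.log_prob(x)  # one value per mini-batch element, summed over components
print(log_p.shape)     # torch.Size([4])

# Same totals as the two-plate version, which scores elementwise in the
# [n_components, size_mini_batch] layout and sums over components afterwards.
plate_log_p = Normal(loc, scale).log_prob(x.T).sum(0)
print(torch.allclose(log_p, plate_log_p))  # True
```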

Then you can use a simple multivariate guide:

def guide_v2():
    with pyro.plate("mini_batch", size_mini_batch, dim=-1):
        lam = net(mini_batch)  # mini_batch on dim 0.
        assert lam.shape == (size_mini_batch, 2 * n_components)
        loc, log_scale = lam.reshape(size_mini_batch, 2, n_components).unbind(-2)
        # diagonal [size_mini_batch, n_components, n_components]; a full lower-triangular
        # factor here gives a non-diagonal covariance
        scale_tril = log_scale.exp().diag_embed()
        pyro.sample("a", dist.MultivariateNormal(loc, scale_tril=scale_tril))
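The scale_tril in guide_v2 is still diagonal. To get the non-diagonal covariance the question asks for, one option (an assumption on my part, not the only way) is to let the network also predict the strictly lower-triangular entries. A minimal sketch, with a hypothetical helper build_scale_tril that assumes net outputs n_components locs, n_components log-diagonal entries and n_components * (n_components - 1) // 2 off-diagonal entries per mini-batch element:

```python
import torch
from torch.distributions import MultivariateNormal


def build_scale_tril(lam: torch.Tensor, n_components: int):
    """Hypothetical helper (not part of Pyro): split a network output of width
    2*C + C*(C-1)//2 into loc [..., C] and a lower-triangular scale_tril
    [..., C, C] with a strictly positive diagonal."""
    C = n_components
    n_off = C * (C - 1) // 2
    assert lam.shape[-1] == 2 * C + n_off
    loc = lam[..., :C]
    log_diag = lam[..., C:2 * C]
    off_diag = lam[..., 2 * C:]

    # Scatter the off-diagonal outputs into the strictly lower triangle
    # (offset=-1 excludes the diagonal itself).
    rows, cols = torch.tril_indices(C, C, offset=-1)
    lower = lam.new_zeros(lam.shape[:-1] + (C, C))
    lower[..., rows, cols] = off_diag

    # exp keeps the diagonal positive, as in the diagonal guide.
    scale_tril = lower + torch.diag_embed(log_diag.exp())
    return loc, scale_tril


# Usage with hypothetical sizes; lam stands in for net(mini_batch).
torch.manual_seed(0)
size_mini_batch, n_components = 4, 3
n_out = 2 * n_components + n_components * (n_components - 1) // 2
lam = torch.randn(size_mini_batch, n_out)
loc, scale_tril = build_scale_tril(lam, n_components)
mvn = MultivariateNormal(loc, scale_tril=scale_tril)
a = mvn.rsample()  # [size_mini_batch, n_components], reparameterized
print(a.shape)     # torch.Size([4, 3])
```

In guide_v2 this would replace the reshape/unbind and diag_embed lines, with net widened to n_out outputs. PyTorch’s torch.distributions.transforms.LowerCholeskyTransform is an alternative if the network emits a full square matrix instead.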

Alternatively, if you really want to keep the “components” plate, you can follow the pattern of Pyro’s AutoContinuous guides such as AutoMultivariateNormal: sample an auxiliary joint random variable marked with infer={"is_auxiliary": True}, then transpose it and sample the user-facing variable from a Delta distribution:

def guide():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        lam = net(mini_batch) # mini_batch on dim 0.
        assert lam.shape == (size_mini_batch, 2 * n_components)
        loc, log_scale = lam.reshape(size_mini_batch, 2, n_components).unbind(-2)
        scale_tril = log_scale.exp().diag_embed()
        a_aux = pyro.sample(
            "a_aux",
            dist.MultivariateNormal(loc, scale_tril=scale_tril), 
            infer={"is_auxiliary": True},  # <--- declare an auxiliary guide variable
        )
        assert a_aux.shape == (size_mini_batch, n_components)
        a = a_aux.T  # [n_components, size_mini_batch], matching plates at dims -2 and -1
        with pyro.plate('components', n_components, dim=-2):
            pyro.sample("a", dist.Delta(a))  # <--- sample from a delta
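A quick sanity check of the auxiliary-variable guide, in plain PyTorch with hypothetical sizes: the transposed auxiliary sample has the [n_components, size_mini_batch] shape the two plates expect, and with a diagonal scale_tril the multivariate log-density equals the summed per-component Normal log-densities of the original guide. Since a Delta contributes zero log-density at its own value, the auxiliary site carries all of the guide’s density here, so swapping in a full scale_tril is the only behavioural change.

```python
import torch
from torch.distributions import MultivariateNormal, Normal

torch.manual_seed(0)
size_mini_batch, n_components = 4, 3  # hypothetical sizes

lam = torch.randn(size_mini_batch, 2 * n_components)  # stand-in for net(mini_batch)
loc, log_scale = lam.reshape(size_mini_batch, 2, n_components).unbind(-2)
scale_tril = log_scale.exp().diag_embed()

mvn = MultivariateNormal(loc, scale_tril=scale_tril)
a_aux = mvn.sample()  # [size_mini_batch, n_components]
a = a_aux.T           # [n_components, size_mini_batch]
print(a.shape)        # torch.Size([3, 4])

# Diagonal MVN log-density == summed Normal log-densities in the original layout.
mvn_lp = mvn.log_prob(a_aux)                                     # [size_mini_batch]
normal_lp = Normal(loc.T, log_scale.exp().T).log_prob(a).sum(0)  # [size_mini_batch]
print(torch.allclose(mvn_lp, normal_lp, atol=1e-5))  # True
```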