Is there any difference between using plate or a vectorized distribution?

Hi NumPyro,

I was wondering whether there is any difference between the following two options.

Defining a vector of independent Normals using a plate:

with numpyro.plate("N", 10):
    samples = numpyro.sample(
        'samples',
        dist.Normal(loc=0, scale=1)
    )

Or creating a vectorized distribution of Normals from vectors of parameters:

samples = numpyro.sample(
    'samples',
    dist.Normal(loc=jnp.zeros(10), scale=jnp.ones(10))
)


Sampling

From a sampling perspective there does not seem to be any difference:

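import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

def model():
    with numpyro.plate("N", 10):
        samples = numpyro.sample(
            'samples',
            dist.Normal(loc=0, scale=1)
        )

prior_predictive = Predictive(
    model,
    num_samples=1
)
prior_predictions = prior_predictive(
    jax.random.PRNGKey(0),
)
prior_predictions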
# {'samples': DeviceArray([[-0.38812608, -0.04487164, -2.0427258 ,  0.07932311,
#                 0.33349916,  0.7959976 , -1.4411978 , -1.6929979 ,
#                -0.37369204, -1.5401139 ]], dtype=float32)}

vs

def model():
    samples = numpyro.sample(
        'samples',
        dist.Normal(loc=jnp.zeros(10), scale=jnp.ones(10))
    )
        
prior_predictive = Predictive(
    model,
    num_samples=1
)
prior_predictions = prior_predictive(
    jax.random.PRNGKey(0),
)
prior_predictions
# {'samples': DeviceArray([[-0.38812608, -0.04487164, -2.0427258 ,  0.07932311,
#                 0.33349916,  0.7959976 , -1.4411978 , -1.6929979 ,
#                -0.37369204, -1.5401139 ]], dtype=float32)}

Inference methods

However, I was wondering whether any of the inference methods or other functionality (e.g. SVI guides) rely on the plate context manager for specific behavior.

Some inference methods leverage plates, for example:

  • enumeration: needs a plate to allocate the enumeration dimension properly
  • TraceGraph_ELBO: needs a plate to exploit conditional independence
  • subsampling: needs a plate to perform subsampling

Although we have relaxed the plate requirement in some cases, it is best to always use a plate to declare batch dimensions. In a (num)pyro program, there is no benefit to not using a plate. :slight_smile:


Thanks for the clarification, @fehiepsi. What do you mean by a “(num)pyro program”?

It is either a Pyro program (PyTorch backend) or a NumPyro program (JAX backend). By program, I mean any code written using Pyro or NumPyro.