Hi NumPyro,
I was wondering if there is any difference between the following two options.
Define a vector of independent Normals using `plate`:
```python
with numpyro.plate("N", 10):
    samples = numpyro.sample(
        'samples',
        dist.Normal(loc=0, scale=1)
    )
```
Or create the vectorized (batched) distribution of Normals directly from a vector of parameters:
```python
samples = numpyro.sample(
    'samples',
    dist.Normal(loc=jnp.zeros(10), scale=jnp.ones(10))
)
```
Sampling
From a sampling perspective there does not seem to be any difference:
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive

def model():
    with numpyro.plate("N", 10):
        samples = numpyro.sample(
            'samples',
            dist.Normal(loc=0, scale=1)
        )

prior_predictive = Predictive(
    model,
    num_samples=1
)
prior_predictions = prior_predictive(
    jax.random.PRNGKey(0),
)
prior_predictions
# {'samples': DeviceArray([[-0.38812608, -0.04487164, -2.0427258 ,  0.07932311,
#                            0.33349916,  0.7959976 , -1.4411978 , -1.6929979 ,
#                           -0.37369204, -1.5401139 ]], dtype=float32)}
```
vs
```python
def model():
    samples = numpyro.sample(
        'samples',
        dist.Normal(loc=jnp.zeros(10), scale=jnp.ones(10))
    )

prior_predictive = Predictive(
    model,
    num_samples=1
)
prior_predictions = prior_predictive(
    jax.random.PRNGKey(0),
)
prior_predictions
# {'samples': DeviceArray([[-0.38812608, -0.04487164, -2.0427258 ,  0.07932311,
#                            0.33349916,  0.7959976 , -1.4411978 , -1.6929979 ,
#                           -0.37369204, -1.5401139 ]], dtype=float32)}
```
Inference methods
However, I was wondering whether any of the inference methods or other functionality (e.g. SVI guides) uses the `plate` context manager to define specific behavior?