Thanks, but I don't think I explained my issue clearly, so I made a toy problem with minimal data to demonstrate the issue I'm facing.
I generated fake data to perform variational inference, adapting the PyMC3 linear regression example here to NumPyro.
I have two otherwise identical models: one with and one without the `subsample_size` argument in `numpyro.plate`.
Both models give correct results after training (the parameters are saved in `svi_result.params`).
My issue: suppose I want posterior predictions for 121 `x` values, drawing 100 samples from the posterior. Then I get the following shapes for the predicted `y`:

The shape of `y_pred` for the model without subsampling is `(100, 121, 1)`, which is correct.

The shape of `y_pred` for the model with subsampling is `(100, 16, 1)`, which is confusing.
My question: how can I get the same posterior-predictive shape from both models?
Thanks
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, autoguide, init_to_sample
from numpyro.infer import Trace_ELBO
# Pin NumPyro/JAX to the CPU backend for this script.
numpyro.set_platform("cpu")
# Split the root key: ``rng_key`` drives the data noise (and, later, SVI);
# ``rng_key_predict`` is kept aside for posterior prediction.
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
# Generate fake data: y = intercept + slope * x + standard-normal noise.
size = 200
true_intercept = 1
true_slope = 2
# Column vector of shape (size, 1). ``reshape(1, 1)`` cannot work here: the
# array holds ``size`` elements, and the noise below has shape (size, 1).
x = jnp.linspace(0, 1, size).reshape(size, 1)
# y = intercept + w*x
true_regression_line = true_intercept + true_slope * x
# add noise with the same (size, 1) shape
y = true_regression_line + random.normal(rng_key, shape=(size, 1))
# Define numpyro models
def model_without_subsample(x=None, y=None):
    """Bayesian linear regression whose likelihood uses every row of ``x``.

    Priors: ``sigma ~ HalfCauchy(10)``, ``Intercept ~ Normal(0, 20)`` and
    ``w ~ Normal(0, 20)``. The ``"obs"`` site lives in the ``"data"`` plate,
    whose size is ``x.shape[0]``; its last axis is declared an event axis via
    ``to_event(1)``.

    Args:
        x: Inputs indexed by the plate along axis 0 (presumably shape (N, 1),
            like ``y`` -- confirm). Must not be None, because the plate size is
            read from ``x.shape[0]``.
        y: Observations aligned with ``x``, or None to sample ``"obs"`` instead
            of conditioning on it (e.g. for posterior prediction).
    """
    noise_scale = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    bias = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    slope = numpyro.sample("w", dist.Normal(loc=0, scale=20))
    n_rows = x.shape[0]
    with numpyro.plate("data", n_rows):
        likelihood = dist.Normal(bias + slope * x, scale=noise_scale).to_event(1)
        numpyro.sample("obs", likelihood, obs=y)
def model_with_subsample(x=None, y=None, subsample_size=16):
    """Bayesian linear regression whose likelihood is evaluated on a mini-batch.

    Same priors and likelihood as ``model_without_subsample``, but the
    ``"data"`` plate selects ``subsample_size`` rows of ``x``/``y`` per call,
    so ``"obs"`` has ``subsample_size`` entries along the plate axis.

    Args:
        x: Inputs indexed by the plate along axis 0 (presumably shape (N, 1),
            like ``y`` -- confirm). Required: the plate size is read from
            ``x.shape[0]``.
        y: Observations aligned with ``x``, or None to sample ``"obs"`` instead
            of conditioning on it (e.g. for posterior prediction).
        subsample_size: Rows per mini-batch. Defaults to 16, the batch size
            used for training. Pass ``None`` to use every row of ``x`` -- do
            this for posterior prediction; otherwise ``"obs"`` covers only a
            random subset of 16 inputs instead of one entry per input row.
    """
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
    # ``ind`` holds the row indices chosen for this call; with
    # ``subsample_size=None`` the plate spans (and indexes) all rows of ``x``.
    with numpyro.plate("data", x.shape[0], subsample_size=subsample_size) as ind:
        batch_x = x[ind]
        # No observations at prediction time: "obs" is then sampled.
        batch_y = None if y is None else y[ind]
        numpyro.sample(
            "obs",
            dist.Normal(intercept + w * batch_x, scale=sigma).to_event(1),
            obs=batch_y,
        )


# SVI for model without subsample
guide = autoguide.AutoNormal(model_without_subsample, init_loc_fn=init_to_sample)
# NOTE(review): the step size was 1e3, which looks like a typo for 1e-3; an
# Adam step that large is unlikely to converge -- confirm intended value.
optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model_without_subsample, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 100000, x, y)
# Posterior predictive for model without subsample
posterior_predictive = Predictive(
    model=model_without_subsample,
    guide=guide,
    params=svi_result.params,
    num_samples=100,
)
y_pred = posterior_predictive(rng_key_predict, x=x[:121])['obs']
print(f'shape of y_pred for model without subsample is {y_pred.shape}')

# SVI for model with subsample (mini-batches of 16 rows during training)
guide = autoguide.AutoNormal(model_with_subsample, init_loc_fn=init_to_sample)
optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model_with_subsample, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 100000, x, y)
# Posterior predictive for model with subsample. Subsampling is switched off
# (subsample_size=None) so the plate covers all 121 requested inputs and the
# result has the same shape as the model without subsample.
posterior_predictive = Predictive(
    model=model_with_subsample,
    guide=guide,
    params=svi_result.params,
    num_samples=100,
)
y_pred = posterior_predictive(
    rng_key_predict, x=x[:121], subsample_size=None
)['obs']
print(f'shape of y_pred for model with subsample is {y_pred.shape}')