Shape of Predictions from Model with subsampling

Hi all,

I have a model that I successfully trained with SVI, and the results are good. Now I want to add the subsampling option to the numpyro.plate. The relevant part of the model is given below.
However, I have a problem with the shape of the obs variable when I draw samples from the prior/posterior predictive distribution of the model.

  • prior_predictions['obs'] has the shape (100, 10, 1), i.e. (num_samples, subsample_size, event_shape). How can I get a (100, 1200, 1) array instead, given that I explicitly passed the first 1200 values of x to the model?
def model(x=None, y=None, subsample_size=10):
    """Model with a subsampled likelihood over the rows of ``x``.

    Args:
        x: Input array whose first dimension sets the size of the "data"
            plate. Required in practice despite the ``None`` default, since
            ``x.shape[0]`` is read unconditionally.
        y: Optional observations, indexed by the same rows as ``x``. When
            ``None``, "obs" is sampled instead of being conditioned on.
        subsample_size: Number of rows the "data" plate selects per call.
            Defaults to 10, the value that used to be hard-coded, so existing
            callers are unaffected. Pass ``None`` to use every row. For
            example, ``Predictive(model, ...)(key, x=x[:1200],
            subsample_size=None)`` yields "obs" with one entry per row of
            ``x[:1200]`` instead of only 10.
    """
    # ...
    # Complex model to find the latent variable g
    # (`g` and `sigma_obs` are computed in this elided part).
    # ...
    with numpyro.plate("data", x.shape[0], subsample_size=subsample_size) as ind:
        # `ind` holds the row indices the plate selected for this call; every
        # per-row quantity is sliced with it so its leading dimension matches
        # the plate. NOTE(review): assumes `g` has one entry per row of x.
        batch_y = None if y is None else y[ind]
        batch_g = g[ind]
        numpyro.sample(
            "obs",
            dist.Normal(batch_g, sigma_obs).to_event(1),
            obs=batch_y,
        )


# Load the data and convert both arrays to JAX arrays in one pass.
# Per the original note: x is (30_000, 10) and y is (30_000, 1).
x, y = map(jnp.array, get_data())

# Prior predictive check on the first 1200 rows of x. No y is passed, so
# "obs" is sampled. Because the model's "data" plate subsamples, "obs"
# covers only the selected rows (the post reports shape (100, 10, 1)).
prior_predictive = Predictive(model, num_samples=100)
prior_predictions = prior_predictive(rng_key, x=x[:1200])

Thanks.

I think you can change the 10 to 10000 by making subsample_size an argument of your model.

Thanks, but I don't think I explained my issue clearly. So I made a toy problem with minimal data to show the issue I am facing.

  • I generate fake data to perform variational inference. I adapted the PyMC3 linear regression example to NumPyro.
  • I have two otherwise identical models: one with and one without the subsample_size argument in numpyro.plate.
  • Both models are giving the correct results after training (saved in svi_result.params).
  • My issue: say I want posterior predictions for 121 x values, and I draw 100 samples from the posterior. Then I get the following shapes for the predictions of y:
    • shape of y_pred for the model without subsample is (100, 121, 1), which is correct.
    • shape of y_pred for the model with subsample is (100, 16, 1), which is confusing.

My question is: how can I get the same posterior-prediction shape from both models?
Thanks

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, autoguide, init_to_sample
from numpyro.infer import Trace_ELBO

# Force NumPyro/JAX to run on the CPU backend for this example.
numpyro.set_platform(platform='cpu')
# Three independent PRNG keys:
#   - rng_key for SVI training,
#   - rng_key_predict for posterior prediction,
#   - rng_key_data for the synthetic noise.
# JAX keys must not be reused, so the data noise gets its own key instead of
# sharing `rng_key` with the SVI runs.
rng_key, rng_key_predict, rng_key_data = random.split(random.PRNGKey(0), 3)

# Generate fake data: a straight line plus standard-normal noise.
size = 200
true_intercept = 1
true_slope = 2

# Inputs as a column vector of shape (size, 1).
x = jnp.linspace(0, 1, size).reshape(-1, 1)
# y = intercept + slope * x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + random.normal(rng_key_data, shape=(size, 1))

# Define numpyro models

def model_without_subsample(x=None, y=None):
    """Bayesian linear regression whose likelihood covers every row of ``x``.

    Priors, in sampling order:
        sigma ~ HalfCauchy(10), Intercept ~ Normal(0, 20), w ~ Normal(0, 20).
    Each row's observation is Normal(Intercept + w * x_row, sigma), with the
    trailing dimension declared as an event dimension via ``to_event(1)``.

    ``x`` is required (its first dimension sizes the "data" plate). When
    ``y`` is ``None``, "obs" is sampled rather than observed.
    """
    noise_scale = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    bias = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    slope = numpyro.sample("w", dist.Normal(loc=0, scale=20))
    num_rows = x.shape[0]
    with numpyro.plate("data", num_rows):
        likelihood = dist.Normal(bias + slope * x, scale=noise_scale).to_event(1)
        numpyro.sample("obs", likelihood, obs=y)

def model_with_subsample(x=None, y=None, subsample_size=16):
    """Bayesian linear regression whose "data" plate subsamples the rows.

    Same priors and likelihood as the full-data model (sigma ~ HalfCauchy(10),
    Intercept ~ Normal(0, 20), w ~ Normal(0, 20), obs ~ Normal(Intercept +
    w * x, sigma) with one event dimension). The difference is that the
    "data" plate selects ``subsample_size`` rows per call.

    Args:
        x: Input array; required, since ``x.shape[0]`` sizes the plate.
        y: Optional observations aligned with ``x``. When ``None``, "obs" is
            sampled instead of observed.
        subsample_size: Rows selected per call. Defaults to 16, the batch size
            that used to be hard-coded, so ``svi.run(rng_key, n, x, y)`` is
            unchanged. Pass ``None`` to use every row. For example,
            ``Predictive(...)(key, x=x_new, subsample_size=None)`` returns
            "obs" with one entry per row of ``x_new`` rather than 16.
    """
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
    with numpyro.plate("data", x.shape[0], subsample_size=subsample_size) as ind:
        # `ind` holds the row indices chosen by the plate. x and y are sliced
        # with it so their leading dimension matches the plate. (There is no
        # x-is-None branch: x.shape[0] above already requires x.)
        batch_x = x[ind]
        batch_y = None if y is None else y[ind]
        numpyro.sample(
            "obs",
            dist.Normal(intercept + w * batch_x, scale=sigma).to_event(1),
            obs=batch_y,
        )

# --- Fit the full-data model with SVI -------------------------------------
guide = autoguide.AutoNormal(model_without_subsample,
                             init_loc_fn=init_to_sample)
optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model_without_subsample, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 100_000, x, y)

# --- Posterior predictive on the first 121 inputs -------------------------
# No subsampling, so "obs" has one entry per input row: (100, 121, 1).
posterior_predictive = Predictive(
    model=model_without_subsample,
    guide=guide,
    params=svi_result.params,
    num_samples=100,
)
y_pred = posterior_predictive(rng_key_predict, x=x[:121])["obs"]
print(f"shape of y_pred for model without subsample is {y_pred.shape}")

# --- Fit the subsampled model with SVI ------------------------------------
guide = autoguide.AutoNormal(model_with_subsample,
                             init_loc_fn=init_to_sample)
optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model_with_subsample, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 100_000, x, y)

# --- Posterior predictive on the first 121 inputs -------------------------
# The "data" plate still subsamples here, so "obs" covers only the selected
# rows (the post reports (100, 16, 1) rather than (100, 121, 1)).
posterior_predictive = Predictive(
    model=model_with_subsample,
    guide=guide,
    params=svi_result.params,
    num_samples=100,
)
y_pred = posterior_predictive(rng_key_predict, x=x[:121])["obs"]
print(f"shape of y_pred for model with subsample is {y_pred.shape}")

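How about this? Make subsample_size an argument of the model with a default of None. Predictive then uses every row of x, and you pass a batch size only when training: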

def model_with_subsample(x=None, y=None, subsample_size=None):
    """Linear regression whose "data" plate size is chosen by the caller.

    ``subsample_size=None`` (the default) makes the plate use every row of
    ``x``, so ``Predictive`` returns "obs" with one entry per input row. Pass
    an integer (e.g. ``subsample_size=20`` through ``svi.run``) to subsample
    during training.
    """
    sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
    intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
    w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
    with numpyro.plate("data", x.shape[0], subsample_size=subsample_size) as ind:
        # Slice x and y with the plate's selected row indices.
        batch_x = x[ind]
        batch_y = None if y is None else y[ind]
        numpyro.sample(
            "obs",
            dist.Normal(intercept + w * batch_x, scale=sigma).to_event(1),
            obs=batch_y,
        )


# Train on random batches of 20 rows: extra keyword arguments given to
# svi.run are passed on to the model (and guide).
svi_result = svi.run(rng_key, 100000, x, y, subsample_size=20)
1 Like

Works like a charm. Thanks.