SVI, autoguides, subsampling and initial values

Hi everyone,

I have a hierarchical time series model that estimates per-subject “baselines” from multiple measurements per subject. I want to set initial values for each subject's “baseline”. However, when I use SVI with an autoguide and subsampling, this gives an “Incompatible shapes for broadcasting” error.

Here’s an MRE:

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import optax

N = 1000
m = 10

baselines = np.random.normal(0, 1, size=N)
y = baselines[:, None] + np.random.normal(0, 1, size=(N, m))

def create_plates(y, subsample_size=None):
    # Take the plate size from the data instead of relying on the global N.
    return numpyro.plate("subjects", y.shape[0], subsample_size=subsample_size)

def model(y, subsample_size=None):
    y = jnp.array(y)
    N, m = y.shape
    subj_plate = create_plates(y, subsample_size)
    with subj_plate as ind:
        baselines = numpyro.sample("baselines", numpyro.distributions.Normal(0, 1))
        baselines = jnp.repeat(baselines[:, None], m, axis=1)
        numpyro.sample(
            "y", numpyro.distributions.Normal(baselines, 1).to_event(1), obs=y[ind]
        )

guide = numpyro.infer.autoguide.AutoNormal(
    model,
    create_plates=create_plates,
    # Set init values; without this argument the code runs fine.
    init_loc_fn=numpyro.infer.util.init_to_value(values={"baselines": y[:, 0]}),
)

optimizer = optax.adam(1e-2)

svi = numpyro.infer.SVI(model, guide, optimizer, loss=numpyro.infer.Trace_ELBO())

svi_results = svi.run(jax.random.PRNGKey(0), num_steps=10000, y=y, subsample_size=100)

This gives:

ValueError: Incompatible shapes for broadcasting: shapes=[(100,), (1000,)]

So the init values (shape N=1000) don't match the subsampled site (shape subsample_size=100). Any idea how I can set init values and still use subsampling? Or am I approaching this totally wrong?
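
To show the mismatch outside NumPyro, here is a minimal NumPy-only sketch. It is an illustration of the shapes involved, not of what the autoguide does internally: `ind` is a made-up stand-in for the indices a subsampling plate would draw, and `init_full` plays the role of `y[:, 0]`.

```python
import numpy as np

N, subsample_size = 1000, 100
rng = np.random.default_rng(0)

# One init value per subject, shape (N,); stands in for y[:, 0].
init_full = rng.normal(size=N)

# Hypothetical stand-in for the indices a subsampling plate would draw.
ind = rng.choice(N, size=subsample_size, replace=False)

# Inside the plate the site has shape (subsample_size,), so the full-length
# init vector cannot be broadcast against it.
try:
    np.zeros(subsample_size) + init_full
    mismatch = False
except ValueError:
    mismatch = True

# Indexing the full vector with the subsample indices gives a matching shape.
init_sub = init_full[ind]
print(mismatch, init_sub.shape)
```

So the full-length values would have to be indexed by the plate's indices somewhere before they meet the subsampled site, and I don't see a way to make that happen through init_loc_fn.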

I think this is a limitation of the current implementation. Could you create a feature request? I can’t think of a solution for now.

Will do!