Hi,
I couldnt find a post/forum on subsampling in manual guide apart from the SVI Part II tutorial, which uses a sequential plate instead.
Here is a simple model with the same structure but in vectorized plate form(which didnt work). The autoguide work for this example, when I put the subsampling in the model(but the tutorial say I should put the subsampling in the guide). Thanks
def model(X, t, ls):
# Additional prior on z
with numpyro.plate("n_topics", 15):
b_shape = (15, X.shape[1])
b = numpyro.sample("b", dist.Normal(jnp.zeros(b_shape), jnp.ones(b_shape)).to_event(1))
with numpyro.plate("n_cells", X.shape[0], subsample_size=1024):
z_shape = (X.shape[0], 15)
z = numpyro.sample('z', dist.Normal(jnp.zeros(z_shape), jnp.ones(z_shape)).to_event(1))
z = jax.nn.softmax(z, axis = 1)
mean = jax.nn.softmax(z @ b, axis = 1)
ls = ls.reshape(-1, 1)
numpyro.sample("X", dist.Poisson(ls * mean).to_event(1), obs = X)
def guide(X, t, ls):
with numpyro.plate("n_topics", 15):
b_shape = (15, X.shape[1])
b_loc = numpyro.param("b_s_loc", jnp.zeros(b_shape))
b = numpyro.sample("b_s", dist.Delta(b_loc).to_event(1))
with numpyro.plate("n_cells", X.shape[0], subsample_size = 1024) as ind:
# print(ind[1])
# t = t[ind]
# t_0 = t == 0
# t_1 = t == 1
z_shape = (X.shape[0], 15)
z_loc = numpyro.param('z_loc', jnp.zeros(z_shape))
z = numpyro.sample('z', dist.Delta(z_loc[ind]).to_event(1))
# guide = AutoDelta(model)