Confusion about subsampling in a manual guide

Hi,

I couldn't find a forum post about subsampling in a manual guide, apart from the SVI Part II tutorial, which uses a sequential plate instead.

Here is a simple model with the same structure, but written with a vectorized plate (which didn't work). The autoguide works for this example when I put the subsampling in the model (although the tutorial says I should put the subsampling in the guide). Thanks!

def model(X, t, ls, n_topics=15, subsample_size=1024):
    """Topic-style Poisson count model with mini-batched cells.

    Args:
        X: observed counts, indexed as ``X[cell, gene]`` (2-D, one row per cell).
        t: unused by this model; kept so model and guide share one signature.
        ls: per-cell scale factors (library sizes, presumably). They are
            reshaped to a column so that both ``(n_cells,)`` and
            ``(n_cells, 1)`` inputs work.
        n_topics: number of topics (rows of ``b``).
        subsample_size: number of cells drawn per step. It is clamped to the
            number of cells.
    """
    n_cells, n_genes = X.shape
    subsample_size = min(subsample_size, n_cells)

    # Global topic-by-gene weights; the plate expands the per-topic prior to
    # shape (n_topics, n_genes).
    with numpyro.plate("n_topics", n_topics):
        b = numpyro.sample("b", dist.Normal(jnp.zeros(n_genes), 1.0).to_event(1))

    # Local (per-cell) variables.  The plate draws `ind`, the indices of the
    # cells in this mini-batch, and rescales the log-likelihood by
    # n_cells / subsample_size.  Every per-cell quantity must therefore have
    # len(ind) rows, not n_cells rows.
    with numpyro.plate("n_cells", n_cells, subsample_size=subsample_size) as ind:
        z = numpyro.sample("z", dist.Normal(jnp.zeros(n_topics), 1.0).to_event(1))
        theta = jax.nn.softmax(z, axis=1)  # (batch, n_topics)

        mean = jax.nn.softmax(theta @ b, axis=1)  # (batch, n_genes)
        ls_batch = jnp.reshape(ls, (-1, 1))[ind]  # (batch, 1)
        numpyro.sample("X", dist.Poisson(ls_batch * mean).to_event(1), obs=X[ind])

def guide(X, t, ls, n_topics=15, subsample_size=1024):
    """Point-mass (MAP) guide for ``model`` with per-cell parameters subsampled.

    Args:
        X: observed counts; only its shape ``(n_cells, n_genes)`` is used here.
        t: unused; kept so model and guide share one signature.
        ls: unused; kept so model and guide share one signature.
        n_topics: number of topics; must match the model's value.
        subsample_size: cells per mini-batch; must match the model's value.
            It is clamped to the number of cells.
    """
    n_cells, n_genes = X.shape
    subsample_size = min(subsample_size, n_cells)

    # Small random init (callable init_value, called with a PRNG key) breaks
    # the symmetry between topics.  With an all-zero init, softmax(z) is
    # uniform and every topic row of b receives an identical gradient, so
    # the topics never separate.
    b_loc = numpyro.param(
        "b_s_loc",
        lambda key: 0.1 * jax.random.normal(key, (n_topics, n_genes)),
    )
    with numpyro.plate("n_topics", n_topics):
        # The site name must equal the model's sample site "b".  A differently
        # named site is not paired with it by the ELBO, so b would never be
        # fitted.
        numpyro.sample("b", dist.Delta(b_loc).to_event(1))

    # The full (n_cells, n_topics) parameter is created outside the plate.
    # Each step only the rows for the cells in this mini-batch are used.
    z_loc = numpyro.param(
        "z_loc",
        lambda key: 0.1 * jax.random.normal(key, (n_cells, n_topics)),
    )
    with numpyro.plate("n_cells", n_cells, subsample_size=subsample_size) as ind:
        numpyro.sample("z", dist.Delta(z_loc[ind]).to_event(1))
      
# guide = AutoDelta(model)

Update: I realised it does work; the problem was just my initialization.