Why does the NUTS sampler affect plate subsample indices?

I want to run NUTS on a large dataset, so I'd like to use data subsampling. If I define a model and call it repeatedly, each call evaluates a different random subsample of indices, and everything seems to work.

Note the print call inside the plate.

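import torch
import pyro
import pyro.distributions as dist

def model(data):
    # each call draws a fresh random subsample of 5 indices
    with pyro.plate('data', size=data.shape[0], subsample_size=5) as idx:
        print(idx)
        pyro.sample('obs', dist.Normal(data[idx], 1))


for _ in range(5):
    model(torch.arange(0, 10, dtype=torch.float32))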

>> out:
tensor([5, 2, 1, 9, 6])
tensor([9, 8, 1, 2, 4])
tensor([3, 7, 0, 8, 5])
tensor([7, 2, 4, 1, 5])
tensor([4, 9, 7, 1, 5])

However, once sampling with NUTS starts, the indices stop changing after the first call:

from pyro.infer import MCMC, NUTS

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=100)
mcmc.run(torch.arange(0, 10, dtype=torch.float32))

>> out:
tensor([3, 6, 4, 5, 8])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])

HMC/NUTS do not support data subsampling: they rely on evaluating the exact log density of the model on the full dataset, so I'm not sure what you're trying to do here.
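
One way to see why subsampling clashes with HMC/NUTS: both simulate Hamiltonian dynamics on a potential energy U(theta) = -log p(theta, data) and compare energies along the trajectory, which only makes sense if U(theta) is the same function every time it is evaluated. The plain-PyTorch sketch below assumes a toy model (unknown mean of a unit-variance Normal, flat prior), synthetic data, and a hypothetical potential helper. It shows that the full-data U(theta) is deterministic, while a minibatch estimate rescaled by N / B changes from call to call by far more than the order-1 energy differences that HMC's accept/reject step is sensitive to:

```python
import torch

torch.manual_seed(0)
N, B = 10_000, 100           # full dataset size and minibatch size (assumed)
data = torch.randn(N) + 3.0  # synthetic data with true mean 3 (assumed)

def potential(theta, batch_idx=None):
    """Hypothetical helper: U(theta) = -log p(data | theta) for a N(theta, 1)
    model with a flat prior. If batch_idx is given, the minibatch
    log-likelihood is rescaled by N / B, as a subsampled plate would do."""
    x = data if batch_idx is None else data[batch_idx]
    scale = N / x.shape[0]
    return -scale * torch.distributions.Normal(theta, 1.0).log_prob(x).sum()

theta = torch.tensor(2.9)

# Full data: the same theta always gives the same energy.
full_1, full_2 = potential(theta), potential(theta)

# Subsampled: each call sees a new minibatch, so the energy jumps around.
sub_1 = potential(theta, torch.randperm(N)[:B])
sub_2 = potential(theta, torch.randperm(N)[:B])
spread = torch.stack(
    [potential(theta, torch.randperm(N)[:B]) for _ in range(200)]
).std()

print(full_1.item(), full_2.item())
print(sub_1.item(), sub_2.item())
print("std of subsampled U(theta):", spread.item())
```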

Ah, then I should rely on VI. I was hoping I could run NUTS on a large dataset by subsampling.