I want to run NUTS on a large dataset, so I would like to work with subsamples. If I define a model and call it directly, each call draws a different random subsample (different indices), so everything seems to work.
Note the print call inside the plate.
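import torch
import pyro
import pyro.distributions as dist

def model(data):
    # Draw 5 of the data indices at random on each call
    with pyro.plate('data', size=data.shape[0], subsample_size=5) as idx:
        print(idx)
        pyro.sample('obs', dist.Normal(data[idx], 1))

for _ in range(5):
    model(torch.arange(0, 10, dtype=torch.float32))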
>> out:
tensor([5, 2, 1, 9, 6])
tensor([9, 8, 1, 2, 4])
tensor([3, 7, 0, 8, 5])
tensor([7, 2, 4, 1, 5])
tensor([4, 9, 7, 1, 5])
However, when I start sampling with NUTS, the indices stay constant after the first call.
from pyro.infer.mcmc import MCMC, NUTS

nuts_kernel = NUTS(model)
posterior = MCMC(nuts_kernel,
                 num_samples=100,
                 warmup_steps=100).run(torch.arange(0, 10, dtype=torch.float32))
>> out:
tensor([3, 6, 4, 5, 8])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
tensor([2, 6, 3, 4, 9])
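To make the comparison less dependent on eyeballing the printed tensors, here is a minimal sketch that counts how many distinct subsamples appear in each run. The index sets are copied from the output above as plain Python lists; `count_distinct` is a helper name I made up for this illustration, not part of Pyro.

```python
# Index sets copied from the printed output above, as plain Python lists.
direct_calls = [
    [5, 2, 1, 9, 6],
    [9, 8, 1, 2, 4],
    [3, 7, 0, 8, 5],
    [7, 2, 4, 1, 5],
    [4, 9, 7, 1, 5],
]
nuts_calls = [[3, 6, 4, 5, 8]] + [[2, 6, 3, 4, 9]] * 6


def count_distinct(index_log):
    """Count distinct subsamples, ignoring the order of indices within each.

    Hypothetical helper for this post only.
    """
    return len({frozenset(idx) for idx in index_log})


print(count_distinct(direct_calls))  # 5 distinct subsamples in 5 calls
print(count_distinct(nuts_calls))    # 2 distinct subsamples in 7 calls
```

In the model itself, the same log could presumably be collected by appending idx.tolist() to a list instead of printing, which would let the check run over the full chain rather than the few lines shown here.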