Thanks for the response. I thought this might be the only way to achieve a speed-up, although I was hoping there might be a way to avoid capping the model structure, so that it stays flexible enough to be determined by the data.
Just to check that I understand how to pad with zeros so that the model doesn't recompile: if I set the maximum array size to 100 but only need the first 5 values in the first evaluation, do I still need to sample the unobserved variable with size 100? That is, instead of
ii = numpyro.sample('ii', dist.Categorical(probs=np.tile([0.5,0.5],(5,1))))
I would need to write:
ii = numpyro.sample('ii', dist.Categorical(probs=np.tile([0.5,0.5],(100,1))))
and then mask the calculation of the log_prob?
Also, is there a way to see when the model is recompiling? Out of interest, I defined the discrete HMC kernel and MCMC objects and then ran mcmc.run(random.PRNGKey(0), data) three times consecutively, but the time taken did not change:
sample: 100%|████████████████████████████████████████████████████████████| 1001/1001 [00:08<00:00, 118.43it/s, 1023 steps of size 5.71e-04. acc. prob=1.00]
sample: 100%|████████████████████████████████████████████████████████████| 1001/1001 [00:08<00:00, 120.68it/s, 1023 steps of size 5.71e-04. acc. prob=1.00]
sample: 100%|████████████████████████████████████████████████████████████| 1001/1001 [00:08<00:00, 123.40it/s, 1023 steps of size 5.71e-04. acc. prob=1.00]
I expected the second and third runs to be slightly faster, since the model had already been compiled. I assume something in my model is triggering recompilation, but I'm unsure where to look, so I was hoping there might be a way to see when recompilation is triggered.
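A sketch of two general JAX ways to make recompilation visible (not specific to NumPyro): the `jax_log_compiles` config flag logs every compilation through Python logging, and a Python side effect inside a jitted function only runs when JAX retraces it. The function `masked_sum` and the `trace_count` counter are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp

# JAX config flag: log a message whenever a function is compiled,
# so any recompilation shows up in the output.
jax.config.update("jax_log_compiles", True)

trace_count = 0  # illustrative counter, not part of any library


@jax.jit
def masked_sum(x, n_active):
    # Python side effects run only while JAX traces the function, so this
    # counter increases once per (re)trace, not once per call.
    global trace_count
    trace_count += 1
    mask = jnp.arange(x.shape[0]) < n_active
    return jnp.sum(jnp.where(mask, x, 0.0))


x_padded = jnp.ones(100)
first = masked_sum(x_padded, 5)     # traces and compiles for shape (100,)
second = masked_sum(x_padded, 7)    # same shape and dtype: reuses compiled code
third = masked_sum(jnp.ones(5), 5)  # new shape (5,): triggers a retrace
```

If the same flag is set before calling `mcmc.run`, a fresh batch of compile messages on the second or third run would suggest that something (for example, a change in array shapes or dtypes) is triggering recompilation. NumPyro's `MCMC` also accepts `jit_model_args=True`, which compiles the potential energy as a function of the model arguments so that later runs on data of the same shape can reuse it; whether that applies to this particular setup is an assumption worth checking.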