Speeding up compilation time for DiscreteHMCGibbs

Hi all, I am running into a similar problem as described here: I am trying to speed up a DiscreteHMCGibbs kernel. I have a model in which the latent parameters change dimension, so the model needs recompiling after each change. After profiling, I found that the compilation function takes up most of the runtime.

I was wondering if you had any suggestions for improving the runtime/compilation speed? My current thought is that hand-deriving the gradients and providing them might help reduce the compilation time, but I'm not sure how much that would speed up the process.

Thank you for any help!

I think one solution is to use a mask to make the dimensions consistent. For example, the function lambda x: x.sum() will be recompiled when x=jnp.ones(5) in the first run and x=jnp.ones(3) in the second run. To avoid triggering recompilation, you can simply change the second x to jnp.array([1., 1., 1., 0., 0.]).
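A minimal sketch of this behaviour (tracing happens once per new input shape; the Python side effect below only runs at trace time, so the counter, whose name is just for illustration, counts compilations):

```python
import jax
import jax.numpy as jnp

n_traces = 0

@jax.jit
def summed(x):
    global n_traces
    n_traces += 1  # side effects only run while JAX traces, i.e. on (re)compilation
    return x.sum()

summed(jnp.ones(5))                      # shape (5,): traced and compiled
summed(jnp.ones(3))                      # shape (3,): traced and compiled again
summed(jnp.array([1., 1., 1., 0., 0.]))  # shape (5,) again: cache hit, no retrace
print(n_traces)  # 2
```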

Thanks for the response. I was thinking this might be the only way to achieve a speed-up, although I was hoping there might be a way to avoid capping the model structure, leaving it flexible enough to be determined by the data.

Just to check that I understand how to pad with zeros so that the model doesn't recompile: if I set the maximum array size to 100 but only need the first 5 values in the first evaluation, do I still need to sample the unobserved variable with size 100? So instead of

ii = numpyro.sample('ii', dist.Categorical(probs=np.tile([0.5,0.5],(5,1))))

I would need to write:

ii = numpyro.sample('ii', dist.Categorical(probs=np.tile([0.5,0.5],(100,1))))

and then mask the calculation of the log_prob?

Also, is there a way to see when the model is recompiling? Out of interest, I defined the DiscreteHMCGibbs kernel and MCMC objects, then ran mcmc.run(random.PRNGKey(0), data) three times consecutively, but there was no change in the time taken:

sample: 100%|████████████████████████████████████████████████████████████| 1001/1001 [00:08<00:00, 118.43it/s, 1023 steps of size 5.71e-04. acc. prob=1.00]
sample: 100%|████████████████████████████████████████████████████████████| 1001/1001 [00:08<00:00, 120.68it/s, 1023 steps of size 5.71e-04. acc. prob=1.00]
sample: 100%|████████████████████████████████████████████████████████████| 1001/1001 [00:08<00:00, 123.40it/s, 1023 steps of size 5.71e-04. acc. prob=1.00]

I would have expected the second and third runs to be slightly faster, since the model had already been compiled. I assume there must be something triggering recompilation in my model, but I'm unsure where to look, so I was hoping there might be a way to see when recompilation is triggered.

Sorry, I think I misunderstood your question. It seems that you are running different MCMC samplers with different latent sizes? Please ignore my last comment about masking. Currently, we only trigger compiling if the model structure (latent shapes, variable dependencies, …) is the same and the data shape is the same.
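As for seeing when recompilation happens, one option (a sketch, assuming a reasonably recent JAX) is the jax.log_compiles context manager, which emits a log message each time XLA compiles a new version of a function:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x.sum()

# each new input shape triggers a fresh compile, which log_compiles reports
with jax.log_compiles():
    f(jnp.ones(5))  # compiled: a log line is emitted
    f(jnp.ones(5))  # cache hit: no log line
    f(jnp.ones(3))  # new shape: compiled again, another log line
```

You can also turn this on globally with jax.config.update("jax_log_compiles", True) and then watch the logs while mcmc.run executes.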

I think I am misunderstanding what is happening. My assumption was that the compilation is triggered every time the model structure changes, not when it stays the same. Is that the case?

You’re right. I meant the reverse. T__T

Okay, that makes sense. It looks like the only way to improve the runtime for my application is to focus on the compilation itself.

Are there any general guidelines for improving compilation time? In particular, is it possible to provide gradient information rather than using autodiff? And if so, would this help speed up the compilation process?

Yes, it is possible to define custom derivatives in JAX (see this tutorial). I don't think it will save much compilation time, but it is worth trying. (I don't have tips for dealing with varying shapes in JAX; I would try to use masking where possible.)
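As a sketch of what a hand-derived gradient looks like (the potential function here is a toy stand-in, not your model's potential energy):

```python
import jax
import jax.numpy as jnp

# toy potential with a hand-derived derivative attached via jax.custom_jvp
@jax.custom_jvp
def potential(x):
    return 0.5 * jnp.sum(x ** 2)

@potential.defjvp
def potential_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # analytic gradient of 0.5 * ||x||^2 is x, so the JVP is <x, x_dot>
    return potential(x), jnp.vdot(x, x_dot)

# reverse-mode differentiation works too, since the JVP is linear in x_dot
g = jax.grad(potential)(jnp.arange(3.0))
print(g)  # [0. 1. 2.]
```

In practice the custom rule replaces the autodiff trace of the function body, but XLA still has to compile the resulting graph, which is why the compile-time savings are uncertain.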

Great, thank you very much for your help! Will give this a go and see if it helps.