I am trying to understand how NumPyro and JIT compilation work together. I typically run HMC with the following pattern:
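```python
import arviz as az
import jax
from numpyro.infer import MCMC, NUTS

def run_posterior(self):
    nuts_kernel = NUTS(self.model_func)
    mcmc = MCMC(
        nuts_kernel,
        num_samples=10000,
        num_warmup=10000,
        num_chains=10,
        chain_method="parallel",  # needs 10 devices, otherwise chains run sequentially
        progress_bar=False,
    )
    # rng_key is not defined in the original snippet; seed 0 is an arbitrary choice
    rng_key = jax.random.PRNGKey(0)
    # self.model_args is forwarded to self.model_func as keyword arguments
    mcmc.run(rng_key, **self.model_args)
    trace = az.from_numpyro(mcmc)
    return trace
```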
`self.model_func` takes `self.model_args` as its parameters. My question is: when I execute `mcmc.run`, is the sampling code JIT-compiled by NumPyro automatically, or do I have to `jit` it explicitly (it seems like the former)? Does it treat all the inputs of `self.model_func` as static arguments?
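For context, this is my understanding of plain `jax.jit`, sketched without NumPyro (the function names and the `trace_count` counter are hypothetical, just for illustration). Ordinary arguments are traced, so calling again with new values of the same shape and dtype reuses the compiled code. Arguments marked with `static_argnums` are instead baked into the compiled function, which is re-traced for every distinct value. The counter only increments while JAX is tracing:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Counts how often each function is traced (Python side effects run only during tracing)
trace_count = {"traced": 0, "static": 0}

@jax.jit
def scaled_sum_traced(x, scale):
    trace_count["traced"] += 1
    return jnp.sum(x) * scale  # scale is an abstract tracer here

@partial(jax.jit, static_argnums=1)
def scaled_sum_static(x, scale):
    trace_count["static"] += 1
    return jnp.sum(x) * scale  # scale is a concrete Python float here

x = jnp.ones(3)
for s in (1.0, 2.0, 3.0):
    scaled_sum_traced(x, s)
    scaled_sum_static(x, s)

print(trace_count)  # {'traced': 1, 'static': 3}
```

So the question is which of these two behaviours `mcmc.run` applies to the entries of `self.model_args`.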
An additional source of confusion: suppose `self.model_func` calls another function, `self.anotherfunction`. How are the inputs of `self.anotherfunction` treated under JIT? Are they treated as static?
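To make the second question concrete, here is a small plain-JAX sketch; `another_function` and `model_like` are hypothetical stand-ins for `self.anotherfunction` and `self.model_func`. As far as I can tell, a helper called inside a jitted function is not compiled on its own: it runs as ordinary Python while the caller is traced, receives tracers for anything derived from the jitted inputs, and receives ordinary Python objects (here the literal `2`) for constants:

```python
import jax
import jax.numpy as jnp

# Records the type of each argument the helper sees, once per trace
helper_log = []

def another_function(a, power):
    # Hypothetical stand-in for self.anotherfunction: plain Python, no jit of its own
    helper_log.append((type(a).__name__, type(power).__name__))
    return a ** power

@jax.jit
def model_like(x):
    # x is traced; the literal 2 stays an ordinary Python int
    return another_function(x, 2).sum()

print(model_like(jnp.arange(3.0)))        # 5.0
print(model_like(jnp.arange(3.0) + 1.0))  # 14.0, reuses the compiled code
print(len(helper_log))  # 1: the helper only ran during the single trace
print(helper_log[0])    # first entry is a Tracer subclass name, second is 'int'
```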
Thank you so much!