Hello everyone,
I am trying to understand how NumPyro and JIT compilation work together. I typically run HMC with the following pattern:
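    import arviz as az
    from numpyro.infer import MCMC, NUTS

    def run_posterior(self):
        # Build the NUTS kernel around the model function
        nuts_kernel = NUTS(self.model_func)
        mcmc = MCMC(
            nuts_kernel,
            num_samples=10000,
            num_warmup=10000,
            num_chains=10,
            chain_method="parallel",
            progress_bar=False,
        )
        # rng_key is a jax.random key defined elsewhere
        mcmc.run(rng_key, **self.model_args)
        trace = az.from_numpyro(mcmc)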
where self.model_func takes self.model_args as its parameters. My question is: when I execute mcmc.run, is the sampling code jitted by NumPyro automatically, or do I have to jit it explicitly (it seems like the former)? Does it treat all the inputs of self.model_func as static arguments?
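To make the distinction I am asking about concrete, here is a minimal plain-JAX sketch (not NumPyro code) of how I understand traced versus static arguments. The function names and the trace_count dictionary are made up for illustration, and I am not assuming NumPyro does either of these internally.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counter: Python side effects run only while JAX traces a function
trace_count = {"traced": 0, "static": 0}


@jax.jit
def scale_traced(x, factor):
    # `factor` is a traced value here, so its concrete value is not baked in
    trace_count["traced"] += 1
    return x * factor


@partial(jax.jit, static_argnums=1)
def scale_static(x, factor):
    # `factor` is static here, so each new value triggers a fresh trace
    trace_count["static"] += 1
    return x * factor


x = jnp.arange(3.0)
for factor in (2.0, 3.0):
    scale_traced(x, factor)
    scale_static(x, factor)

print(trace_count)
```

As far as I can tell, the traced version is compiled once and reused for both values, while the static version is retraced for each distinct value of factor.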
An additional point of confusion: suppose self.model_func calls another function, self.anotherfunction. How are the inputs of self.anotherfunction treated under jit? Are they treated as static?
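Here is a small plain-JAX sketch of the nested-call situation as I picture it. The names another_function and model_like are hypothetical stand-ins for self.anotherfunction and self.model_func, and seen_types is only there to record what the helper receives. My understanding is that only the outer function is jitted and the helper simply runs as ordinary Python during tracing, but I would like to confirm that this carries over to NumPyro models.

```python
import jax
import jax.numpy as jnp

# Hypothetical log of the Python types the helper receives during tracing
seen_types = []


def another_function(x, power):
    # Stand-in for self.anotherfunction; it is not jitted on its own
    seen_types.append((type(x).__name__, type(power).__name__))
    return x ** power


@jax.jit
def model_like(x):
    # Stand-in for self.model_func; the literal 2 is a plain Python int,
    # so the helper sees it as a concrete constant rather than a traced value
    return another_function(x, 2).sum()


result = model_like(jnp.arange(3.0))
print(result, seen_types)
```

In this sketch the helper sees whatever the outer function passes in: x arrives as a JAX tracer because it comes from a traced argument, while power stays a regular Python int.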
Thank you so much!
Alan