Is this optimal syntax for numpyro model input parameters?

I’m writing a numpyro model that calls a function whose single argument is a jax.numpy array containing many model parameters. I only want to vary/infer a subset of those parameters with numpyro. Currently my numpyro model has this call signature:


import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# initialize the full array of input parameters for f(rand_args);
# of these, only a few need to be varied/inferred in numpyro
rand_args = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])

def model(rand_args, obs_data, obs_sigma):

    # sample our one free parameter
    alpha = numpyro.sample('alpha', dist.TruncatedNormal(low=0.0, loc=1.0, scale=0.5))

    # update rand_args with the random variable
    rand_args = rand_args.at[0].set(alpha)

    # call our main function to predict the output given rand_args
    pred = f(rand_args)

    # likelihood
    numpyro.sample('obs', dist.LogNormal(jnp.log(pred), obs_sigma), obs=obs_data)

Is the above usage optimal, where rand_args is a global jax array that is passed into the numpyro model and updated inside it with .at? Or could it cause a slowdown, issues with sampling, etc.?

Is there a more elegant, simpler way to do this?

Anyone? Is it possible that this causes the numpyro model to be recompiled?

I’m finding that my function f(rand_args) evaluates in 0.5 s and its gradient in 0.05 s. But the NUTS progress bar reports upwards of ~60–120 seconds per iteration, and one chain always gets stuck while the other 3 finish. I’m wondering whether the numpyro model or my f function is being recompiled each time because of the way I’m updating rand_args inside the numpyro model…

How do I check how many times the numpyro model got recompiled? I know I can do it for JAX functions with f._cache_size().

Maybe add some print statements to your model? Compiled code won’t print anything.
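To illustrate this, here is a minimal sketch (using a toy jitted function g, not the actual model) showing that a Python-side side effect such as print only fires while JAX traces/compiles the function, not on cached executions:

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}

@jax.jit
def g(x):
    # Python-side side effects run only during tracing/compilation,
    # not when the cached compiled code is executed
    trace_count["n"] += 1
    return x * 2.0

g(jnp.arange(3.0))       # first call: traces and compiles
g(jnp.arange(3.0))       # same shape/dtype: cached, no retrace
g(jnp.arange(4.0))       # new shape: retraces and recompiles
print(trace_count["n"])  # 2
```

So a counter (or a print) inside the model body tells you how many times tracing happened.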

Is the above usage optimal where rand_args is a global jax array that is an input to the numpyro model and is updated inside the model with .at ? Or would it cause a slowdown, issues with sampling, etc.?

rand_args inside the model is a local array (it is an input of your model), so the global one won’t be updated. rand_args.at[0].set(alpha) creates a new array.
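A quick sketch of this point, outside of any numpyro model: .at[...].set(...) is a functional update that returns a new array and leaves the original untouched.

```python
import jax.numpy as jnp

rand_args = jnp.array([0.0, 1.0, 2.0, 3.0, 4.0])

# .at[...].set(...) is a functional update: it returns a new array
# and does not modify rand_args in place
new_args = rand_args.at[0].set(9.0)

print(rand_args[0])  # 0.0 -- original array unchanged
print(new_args[0])   # 9.0
```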

Thanks @fehiepsi ! Is there any need to provide rand_args as an input in this minimal example? Would jit_model_args speed up the sampling, or what use case is it intended for?

When I print(pred) inside the model, the first iteration prints the actual values of the array, and all subsequent print statements during subsequent numpyro iterations say

Traced<ShapedArray(float64[50])>with<BatchTrace(level=3/0)> with
  val = Traced<ShapedArray(float64[5,50])>with<DynamicJaxprTrace(level=2/0)>
  batch_dim = 0

meaning the model is indeed compiled. However, what I am asking is: is it possible to check whether each new iteration causes the model to be recompiled, maybe because of the input rand_args? I just confirmed that my jitted f function has f._cache_size() == 1, so it is not being recompiled, and my guess is I’m OK. It would be nice if the numpyro MCMC object had a similar attribute/method to check how many times compilation happened.
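For anyone else checking this on a jitted function, a small sketch of the _cache_size() check described above (f here is a toy stand-in): the cache grows by one entry per new input shape/dtype that triggers a compile.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x ** 2)

f(jnp.arange(5.0))
f(jnp.arange(5.0))       # same shape/dtype: served from the cache
print(f._cache_size())   # 1 -- compiled once so far
f(jnp.arange(6.0))       # new shape triggers a fresh compile
print(f._cache_size())   # 2
```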

mcmc = MCMC(AIES(model), num_warmup=10, num_samples=10, num_chains=10, chain_method='vectorized')

mcmc.run(key(1), rand_args=rand_args, obs_data=obs_data, obs_sigma=obs_sigma)

mcmc.print_summary()

By “new iteration”, do you mean a new mcmc.run(...)? If so, jit_model_args will be helpful, provided you supply a new rand_args with the same shape and dtype.

The number of print statements is the number of times compilation happened.