Jit subfunctions of a model?

I have a model in which there are several deterministic (but data-dependent) transformations. Some of these transformations slow the model down significantly.

Should the user try to jit such functions themselves, or is it better to leave the functions as-is in the model definition and let NumPyro jit the whole thing?

Toy code:


```python
import numpyro


def model(data):
    a = numpyro.sample("a", ...)
    b = numpyro.sample("b", ...)

    c = numpyro.deterministic("c", some_function(a, b))


def some_function(a, b, c: float = 1.0, d: float = 0.0):
    ...
    return value
```

In general, should some_function be jitted (with e.g. static_argnames=["c", "d"]) if it’s possible to do so? Or is it better to leave things un-jitted and let NumPyro try to optimize over the whole model definition?
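A minimal JAX sketch of the two options. The body of some_function is elided in the post, so the one below is a hypothetical stand-in, and outer_with_inner_jit / outer_plain are illustrative names. When the caller is itself under jax.jit, the explicitly jitted version (with c and d as static_argnames) is traced into the outer computation and the whole outer function is compiled as one program, so it gives the same result as the plain version.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical stand-in for some_function; c and d are static, so changing
# their values triggers a retrace and recompile of this function.
@partial(jax.jit, static_argnames=["c", "d"])
def some_function_jit(a, b, c: float = 1.0, d: float = 0.0):
    return c * jnp.tanh(a * b) + d


# The same computation with no explicit jit.
def some_function_plain(a, b, c: float = 1.0, d: float = 0.0):
    return c * jnp.tanh(a * b) + d


@jax.jit
def outer_with_inner_jit(a, b):
    # The inner jitted call is traced as part of the outer computation.
    return some_function_jit(a, b, c=2.0, d=0.5)


@jax.jit
def outer_plain(a, b):
    # The plain function is simply traced into the outer computation.
    return some_function_plain(a, b, c=2.0, d=0.5)


a, b = jnp.float32(0.3), jnp.float32(1.7)
x = outer_with_inner_jit(a, b)
y = outer_plain(a, b)
print(jnp.allclose(x, y))  # True
```

Under an outer jit, the extra decorator mainly affects how the program is traced, not whether the inner code gets compiled.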

I think JAX prefers that we leave things un-jitted, or at least it did in the past.
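A small illustration of why leaving the function un-jitted usually costs nothing. This is a sketch assuming the model ends up evaluated inside an outer jax.jit, as is typical for NumPyro's MCMC samplers. The function body and log_density_like are hypothetical stand-ins. The Python counter increments only while JAX traces the function, so it shows that the plain some_function is traced once and compiled into the outer program. Later calls with inputs of the same shape and dtype reuse that compiled code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times Python executes the function body


def some_function(a, b, c: float = 1.0, d: float = 0.0):
    # Hypothetical body; this Python side effect only runs during tracing.
    global trace_count
    trace_count += 1
    return c * jnp.tanh(a * b) + d


@jax.jit
def log_density_like(a, b):
    # Stand-in for a model evaluation compiled by an outer jit (e.g. an MCMC step).
    return -0.5 * (a**2 + b**2) + some_function(a, b)


# Same shapes and dtypes each time, so the cached compiled program is reused.
for a, b in [(0.1, 0.2), (1.0, -1.0), (2.5, 0.3)]:
    log_density_like(jnp.float32(a), jnp.float32(b))

print(trace_count)  # 1
```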

I see, thanks for your reply!