I have a model in which there are several deterministic (but data-dependent) transformations. Some of these transformations slow the model down significantly.
Should the user try to jit such functions themselves, or is it better to leave the functions as-is in the model definition and let NumPyro jit the whole thing?
Toy code:
```python
def model(data):
    a = numpyro.sample("a", ...)
    b = numpyro.sample("b", ...)
    c = numpyro.deterministic("c", some_function(a, b))

def some_function(a, b, c: float = 1.0, d: float = 0.0):
    ...
    return value
```
In general, should `some_function` be jitted (e.g. with `static_argnames=["c", "d"]`) whenever it's possible to do so? Or is it better to leave things un-jitted and let NumPyro optimize over the whole model definition?
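The sketch below illustrates the trade-off. It assumes that the model's log density is evaluated under an outer `jax.jit`, which is what NumPyro's `MCMC` does when compiling the sampler. `outer_plain` and `outer_nested` are hypothetical stand-ins for that compiled computation, and the body of `some_function` is again a placeholder. When a jitted function is called while another function is being traced for `jit`, JAX traces it into the outer program, so both versions end up in one compiled computation and compute the same result.

```python
import jax
import jax.numpy as jnp


def some_function(a, b, c: float = 1.0, d: float = 0.0):
    # Hypothetical placeholder for the expensive transformation.
    return c * jnp.exp(a) * jnp.sin(b) + d


# Separately jitted version with c and d as static (compile-time) arguments.
some_function_jit = jax.jit(some_function, static_argnames=["c", "d"])


@jax.jit
def outer_plain(a, b):
    # Stand-in for the compiled log density: calls the un-jitted function.
    return jnp.sum(some_function(a, b, c=2.0))


@jax.jit
def outer_nested(a, b):
    # Same computation, but calling the separately jitted function.
    return jnp.sum(some_function_jit(a, b, c=2.0))


a = jnp.linspace(-1.0, 1.0, 5)
b = jnp.linspace(0.0, 2.0, 5)
print(outer_plain(a, b), outer_nested(a, b))

# Inspect the traced program: the inner jitted call shows up as a nested
# sub-computation inside the outer one rather than as a separate program.
print(jax.make_jaxpr(outer_nested)(a, b))
```

As I understand it, an explicit inner `jax.jit` is more likely to pay off where `some_function` runs eagerly, outside any compiled NumPyro routine (for example, when post-processing posterior samples in plain Python), because eager JAX dispatches operations one at a time. Timing both variants on the real function is the reliable way to decide.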