I have a question about using a deterministic function inside a NumPyro model. A typical NumPyro model looks like:
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    b = numpyro.sample("b", dist.Normal(1.0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mu = f(x, a, b)  # some deterministic function, e.g. mu = a*x + b
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
My question is whether it’s possible to replace the analytical expression with a more complex deterministic Python program that takes the sampled parameters and x as inputs and returns its result as a JAX array (jnp.ndarray). Are there any limitations that I should be aware of?
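To make the question concrete, here is a minimal sketch of the kind of function I have in mind for `f`. The iterative update rule and the `n_steps` parameter are made up purely for illustration. I am assuming the function would need to be written with `jax.numpy` and `jax.lax` operations so that JAX can trace and differentiate it, which is part of what I would like to confirm.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f(x, a, b, n_steps=5):
    """Hypothetical multi-step deterministic computation (illustration only)."""

    def step(mu, _):
        # One update of an arbitrary iterative scheme, using only jnp operations
        mu = a * jnp.tanh(mu) + b + 0.1 * x
        return mu, None

    mu0 = jnp.zeros_like(x)
    # lax.scan runs `step` n_steps times in a way JAX can trace
    mu, _ = lax.scan(step, mu0, xs=None, length=n_steps)
    return mu


x = jnp.linspace(-1.0, 1.0, 4)

# The function can be compiled with jit ...
mu = jax.jit(f)(x, 0.1, 1.0)

# ... and differentiated with respect to a parameter
dmu_da = jax.grad(lambda a: f(x, a, 1.0).sum())(0.1)
```

The result `mu` has the same shape as `x`, so it could be passed to `dist.Normal(mu, sigma)` in the model above exactly like the analytical expression.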
Thank you in advance for your help!