Hello everyone,

I have a question about using a deterministic function inside a NumPyro model. A typical NumPyro model looks like:

```
import numpyro
import numpyro.distributions as dist

def f(x, a, b):
    # placeholder deterministic function, e.g. mu = a*x + b
    return a * x + b

def model(x, y=None):
    a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    b = numpyro.sample("b", dist.Normal(1.0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mu = f(x, a, b)  # deterministic function of x and the sampled parameters
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
```

My question is whether it's possible to replace the analytical expression with a more complex deterministic Python program that takes the sampled parameters and `x` as inputs and returns a `jax.ndarray`. Are there any limitations I should be aware of?

Thank you in advance for your help!