Hello everyone,
I have a question about using a deterministic function inside a NumPyro model. A typical NumPyro model looks like:
import numpyro
import numpyro.distributions as dist

def model(x, y=None):
    # Priors on the parameters
    a = numpyro.sample("a", dist.Normal(0.0, 0.2))
    b = numpyro.sample("b", dist.Normal(1.0, 0.5))
    sigma = numpyro.sample("sigma", dist.Exponential(1.0))
    mu = f(x, a, b)  # f is a placeholder for some deterministic function, e.g. mu = a*x + b
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
My question is whether it is possible to use a more complex deterministic Python program in place of an analytical expression: one that takes the sampled parameters and x as inputs and returns its result as a JAX array (jax.numpy.ndarray). Are there any limitations that I should be aware of?
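To make the question concrete, here is a minimal sketch of the kind of program I have in mind. The name `simulate`, the toy equation dy/dt = a*y + b, and the fixed-step Euler scheme are all just illustrative assumptions, not anything from NumPyro itself. As far as I understand, gradient-based samplers such as NUTS need the function to be traceable and differentiable by JAX, so the sketch only uses JAX operations and checks that `jax.jit` and `jax.grad` both work on it.

```python
import jax
import jax.numpy as jnp
from jax import lax


def simulate(x, a, b, n_steps=100):
    """Hypothetical stand-in for f(x, a, b).

    Integrates dy/dt = a*y + b from t=0 to t=x with y(0)=0 using
    n_steps explicit Euler steps, independently for each element of x.
    """
    dt = x / n_steps  # per-element step size, same shape as x

    def step(_, y):
        # One Euler update; uses only JAX-traceable arithmetic.
        return y + dt * (a * y + b)

    # n_steps is a plain Python int, so the trip count is static and
    # reverse-mode differentiation through the loop is supported.
    return lax.fori_loop(0, n_steps, step, jnp.zeros_like(x))


x = jnp.linspace(0.0, 1.0, 5)

# The function compiles under jit and returns an array shaped like x.
mu = jax.jit(simulate)(x, 0.1, 1.0)

# It is also differentiable with respect to a parameter.
grad_a = jax.grad(lambda a: simulate(x, a, 1.0).sum())(0.1)

# Inside the model above, this would replace the placeholder call:
#     mu = simulate(x, a, b)
```

What I am unsure about is how far this extends, for example to programs with Python `if`/`while` branches that depend on the sampled values, or to code that calls out to NumPy or other non-JAX libraries.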
Thank you in advance for your help!