Deterministic function inside NumPyro's model

Hello everyone,

I have a question about using a deterministic function inside a NumPyro model. A typical NumPyro model looks like:

    import numpyro
    import numpyro.distributions as dist

    def f(x, a, b):
        return a * x + b  # placeholder for the deterministic function

    def model(x, y=None):
        a = numpyro.sample("a", dist.Normal(0.0, 0.2))
        b = numpyro.sample("b", dist.Normal(1.0, 0.5))
        sigma = numpyro.sample("sigma", dist.Exponential(1.0))
        mu = f(x, a, b)  # some deterministic function, e.g. mu = a*x + b
        numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

My question is whether it's possible to use a more complex deterministic Python program instead of an analytical expression: one that takes the sampled parameters and x as inputs and returns a jax.ndarray. Are there any limitations that I should be aware of?

Thank you in advance for your help!

If you want to use an inference algorithm that uses gradients (like HMC), you need to make sure everything in the model is differentiable with JAX. Unless you set things up so that your arbitrary Python program also computes derivatives, you'll be out of luck.
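One way to "set things up" for a program JAX cannot trace is sketched below, assuming you can supply its derivatives yourself. `jax.pure_callback` calls ordinary Python/NumPy code from JAX, and `jax.custom_jvp` attaches a hand-written derivative rule. `external_program` is a hypothetical stand-in for the black-box program, and its derivative formula is assumed known analytically here; in practice the program itself might have to return it. Note that callbacks run on the host, so they can be slow and have extra requirements when vmapped; the JAX external-callbacks documentation covers the details for your version.

```python
import jax
import jax.numpy as jnp
import numpy as np


def external_program(x, a, b):
    # Hypothetical stand-in for an arbitrary (non-JAX) Python program:
    # it runs on plain NumPy arrays, so JAX cannot trace or differentiate it.
    return (a * np.sin(x) + b).astype(x.dtype)


@jax.custom_jvp
def f(x, a, b):
    # pure_callback needs the result's shape/dtype declared up front.
    out_type = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
    return jax.pure_callback(external_program, out_type, x, a, b)


@f.defjvp
def f_jvp(primals, tangents):
    # Hand-written derivative rule (assumed known analytically here).
    x, a, b = primals
    dx, da, db = tangents
    y = f(x, a, b)
    dy = a * jnp.cos(x) * dx + da * jnp.sin(x) + db
    return y, dy
```

With this in place, `mu = f(x, a, b)` can stay unchanged inside `model`, and gradients flow through the rule you supplied instead of through the external code.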
