Trace and storing during fitting

Hello! We are trying to understand the trace output and how to store an intermediate result.

For example, we have this small piece of code:

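import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

def model():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([5]))
    print('x =', x)
    k = x + 1  # intermediate value we would like to store

mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))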

Looking at the printed output, we see the array x, but also other values wrapped in a JVPTrace with (level=2/0), …

x = [ 0.3011222 0.7353368 1.944623 -1.2138557 1.8295264]
x = Traced<ConcreteArray([-0.32831764 -1.4084363 1.9748716 0.42638254 -0.01651764])>with<JVPTrace(level=2/0)> with primal = Traced<ConcreteArray([-0.32831764 -1.4084363 1.9748716 0.42638254 -0.01651764]):JaxprTrace(level=1/0)>
tangent = Traced<ShapedArray(float32[5]):JaxprTrace(level=1/0)>

Our questions are:

  1. What are the JaxprTrace levels? What do (level=1/0), (level=2/0)… mean?
    We understand x = [ 0.3011222 0.7353368 1.944623 -1.2138557 1.8295264],
    but what are the other values [-0.32831764 -1.4084363 1.9748716 0.42638254 -0.01651764]?

  2. How do we store an intermediate value, for example k in the piece of code above?

To store intermediate values, please take a look at numpyro.deterministic.
For an example, see here.

Thank you @martinjankowiak! This is great! We can now store the intermediate values.

Just wondering: what are the JaxprTrace levels? What do (level=1/0), (level=2/0)… mean?
They appear when we simply print x.

For that I refer you to the JAX documentation, e.g. this or this.
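
As a rough illustration of the tracing behavior those docs describe: a Python print inside a function that JAX is transforming runs while the function is being traced, so it shows a tracer object standing in for the value rather than a plain array. The sketch below is plain JAX, not NumPyro; f and seen are hypothetical names used only for illustration. NUTS differentiates the model's log density, which is presumably why tracers show up when printing x inside the model, and the exact repr (including the level numbers) depends on the JAX version.

```python
import jax
import jax.numpy as jnp

seen = []  # records the Python type of x on each call (illustrative helper)

def f(x):
    # This print executes whenever Python runs the function body,
    # including while JAX traces it for differentiation.
    print("x =", x)
    seen.append(type(x).__name__)
    return jnp.sum(x ** 2)

x = jnp.arange(3.0)

f(x)                # ordinary call: prints a concrete array
g = jax.grad(f)(x)  # traced call: prints a tracer object instead

print(seen)
print(g)
```

If the goal is to inspect runtime values inside transformed code, jax.debug.print("x = {}", x) is one option in recent JAX versions.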

Great, thank you @martinjankowiak, this is helpful.