Hello! We are trying to understand the trace and how to store an intermediate result.
For example, we have this little piece of code:
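import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

def model():
    x = numpyro.sample("x", dist.Normal(0, 1).expand([5]))
    print('x =', x)
    k = x + 1

mcmc = MCMC(NUTS(model), num_warmup=1000, num_samples=1000)
mcmc.run(random.PRNGKey(0))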
Looking at the printed output, we see the array x, but we also see other values wrapped in a JVPTrace (level=2/0), …
x = [ 0.3011222 0.7353368 1.944623 -1.2138557 1.8295264]
x = Traced<ConcreteArray([-0.32831764 -1.4084363 1.9748716 0.42638254 -0.01651764])>with<JVPTrace(level=2/0)> with
  primal = Traced<ConcreteArray([-0.32831764 -1.4084363 1.9748716 0.42638254 -0.01651764]):JaxprTrace(level=1/0)>
  tangent = Traced<ShapedArray(float32[5]):JaxprTrace(level=1/0)>
…
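A small sketch of where the `Traced<...>` output may come from, assuming that NUTS evaluates the model inside JAX transformations (it needs gradients of the log density, so at least a JVP transformation is active). The same effect can be reproduced with plain JAX: inside `jax.grad`, the argument is a tracer carrying a primal and a tangent, which matches the `JVPTrace ... primal = ... tangent = ...` line above. The function `f` and the recorded list `is_tracer` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

is_tracer = []  # records, per call, whether f received a JAX tracer


def f(x):
    is_tracer.append(isinstance(x, jax.core.Tracer))
    # A plain print(x) here would show a Traced<...> repr under jax.grad;
    # jax.debug.print prints the runtime (primal) value instead.
    jax.debug.print("x = {}", x)
    return jnp.sum(x ** 2)


x0 = jnp.array([0.5, -1.0, 2.0])
f(x0)            # eager call: x is a concrete array
jax.grad(f)(x0)  # under grad: x is a JVP tracer (primal + tangent)
print(is_tracer)  # [False, True]
```

In the NumPyro case the JVP trace sits on top of another trace (the `JaxprTrace` in the primal and tangent), which is why two different levels appear in the output.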
Our questions are:
- What are the JaxprTrace levels? What do (level=1/0), (level=2/0), … mean?
- We understand x = [ 0.3011222 0.7353368 1.944623 -1.2138557 1.8295264], but what are the other values [-0.32831764 -1.4084363 1.9748716 0.42638254 -0.01651764]?
- How do we store an intermediate value, for example k in the piece of code above?