[BUG] Side effect when using jax.Array attributes with random_nnx_module()

I built an NNX model that I wanted to turn into a Bayesian neural network (BNN) using NumPyro. Problems arose with my max-min normalisation when I implemented it as follows:

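import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from flax import nnx
from numpyro.contrib.module import random_nnx_module

class MyModel(nnx.Module):
  def __init__(self, ...):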
    ...
    self.maxima = jnp.array([max_min_tuple[0] for max_min_tuple in max_min_list])
    self.minima = jnp.array([max_min_tuple[1] for max_min_tuple in max_min_list])
    ...

  def __call__(self, x):
    maxima_array = jnp.broadcast_to(self.maxima, x.shape)
    minima_array = jnp.broadcast_to(self.minima, x.shape)
    x = (x - minima_array) / (maxima_array - minima_array)
    ...

def numpyro_model(x, y):
  flax_nn = random_nnx_module('flax_nn', flax_model, prior=dist.Normal())
  nn_res = numpyro.deterministic("nn_res", flax_nn(x))
  ...

This produced the following error:

jax.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[8] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was _body_fn at /var/data/python/lib/python3.11/site-packages/numpyro/infer/hmc_util.py:1003 traced for while_body.

But when I changed my code to

class MyModel(nnx.Module):
  def __init__(self, ...):
    ...
    self.maxima = [max_min_tuple[0] for max_min_tuple in max_min_list]
    self.minima = [max_min_tuple[1] for max_min_tuple in max_min_list]
    ...

  def __call__(self, x):
    maxima_array = jnp.broadcast_to(jnp.array(self.maxima), x.shape)
    minima_array = jnp.broadcast_to(jnp.array(self.minima), x.shape)
    x = (x - minima_array) / (maxima_array - minima_array)
    ...

def numpyro_model(x, y):
  flax_nn = random_nnx_module('flax_nn', flax_model, prior=dist.Normal())
  nn_res = numpyro.deterministic("nn_res", flax_nn(x))
  ...

everything ran fine.

If a constant Python list does not constitute a side effect, then a constant jax.Array should not either. I do not have the expertise to fix random_nnx_module(), but I wanted to note publicly that something is going wrong here.
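For comparison, the kind of side effect the error message describes can be reproduced in plain JAX, without NNX or NumPyro. The sketch below follows the pattern in the `jax.errors.UnexpectedTracerError` documentation: a jitted function saves an intermediate into Python state, and the escaped tracer is later used outside the trace. It only illustrates the general failure mode; it does not claim to show how random_nnx_module leaks the tracer here.

```python
import jax
import jax.numpy as jnp

# Python-side state that outlives the traced function
# (the "global state" the error message warns about).
leaked = []

@jax.jit
def side_effecting(x):
    y = x + 1.0        # y is a tracer while jit traces this function
    leaked.append(y)   # side effect: the tracer escapes the trace
    return y

result = side_effecting(jnp.ones(3))  # fine: the returned value is concrete

try:
    leaked[0] + 1.0  # using the escaped tracer outside of jit
    error_raised = False
except jax.errors.UnexpectedTracerError:
    error_raised = True

print(result, error_raised)
```

The returned value is usable, but the copy stored in `leaked` is still a tracer, so any later computation with it raises `UnexpectedTracerError`.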

Does NNX allow a jax.Array to be stored directly as a module attribute? I'm not familiar with the API, but it seems you need to wrap it in a Variable, as in the NNX Basics guide.