KeyError with svi.init

I have a model (joint_model below) that I’m able to run with NUTS. Now I want to try variational inference to get a (hopefully) faster approximate answer, but svi.init fails with a KeyError:

>>> guide = AutoNormal(joint_model)
... svi = SVI(joint_model, guide, optimizer, loss=Trace_ELBO())
... init_state = svi.init(random.key(0))
... 
Traceback (most recent call last):
  File "<python-input-20>", line 3, in <module>
    init_state = svi.init(random.key(0))
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/infer/svi.py", line 184, in init
    guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 191, in get_trace
    self(*args, **kwargs)
    ~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 846, in __call__
    return cloned_seeded_fn.__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/handlers.py", line 847, in __call__
    return super().__call__(*args, **kwargs)
           ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py", line 411, in __call__
    self._setup_prototype(*args, **kwargs)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "[...]/.venv/lib/python3.13/site-packages/numpyro/infer/autoguide.py", line 401, in _setup_prototype
    full_size = self._prototype_frame_full_sizes[frame.name]
                ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'std_normals_plate'

My joint_model is a composition of several modules. The one that contains std_normals_plate is this:

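import numpyro
import numpyro.distributions as dist

def CAR_layer_units(base_shape, adj_matrix, rho_prior=dist.Beta(1000.0, 0.9)):
    rho = numpyro.sample('rho', rho_prior)

    # one standard-normal draw per row of adj_matrix, placed on batch dim -2
    with numpyro.plate("std_normals_plate", adj_matrix.shape[0], dim=-2):
        zs = numpyro.sample('units', dist.Normal())

    # transform iid normals -> correlated normals -> correlated U[0,1] variables
    # (transform_iid_normal_layers is a helper from my own code, not shown here)
    us = transform_iid_normal_layers(base_shape, adj_matrix, rho, zs)
    return us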

The weird thing is that initializing SVI on this module by itself succeeds:

>>> guide = AutoNormal(CAR_layer_units)
... svi = SVI(CAR_layer_units, guide, optimizer, loss=Trace_ELBO())
... init_state = svi.init(random.key(0), base_shape, adj_matrix)
... 
>>> 

Sadly, it’s going to be impossible to share the whole model here. I’m hoping someone can give me a hint about what may be going wrong just by looking at the error trace.
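
In case it helps to narrow things down, here is a minimal sketch of a check suggested by the last frame of the traceback. It rests on an assumption that is not confirmed anywhere above: that the guide collects plate sizes from the plate sites of the traced model, so a frame name without a plate site of the same name would produce exactly this KeyError. The helper missing_plate_frames, the Frame tuple, and the demo trace (including the name "renamed_plate") are made up for illustration and are not part of NumPyro.

```python
from collections import namedtuple


def missing_plate_frames(model_trace):
    """Map each plate-frame name that has no plate site of the same name
    to the list of sample sites that sit inside that frame."""
    plate_names = {
        name for name, site in model_trace.items() if site["type"] == "plate"
    }
    missing = {}
    for name, site in model_trace.items():
        if site["type"] != "sample":
            continue
        for frame in site.get("cond_indep_stack", ()):
            if frame.name not in plate_names:
                missing.setdefault(frame.name, []).append(name)
    return missing


# Hypothetical stand-in for a traced model, so the sketch runs on its own.
# The mismatched plate name "renamed_plate" is invented for illustration.
Frame = namedtuple("Frame", ["name", "dim", "size"])
demo_trace = {
    "rho": {"type": "sample", "cond_indep_stack": []},
    "renamed_plate": {"type": "plate", "args": (5, None)},
    "units": {
        "type": "sample",
        "cond_indep_stack": [Frame("std_normals_plate", -2, 5)],
    },
}
demo_missing = missing_plate_frames(demo_trace)
print(demo_missing)
```

For the real model, the trace could be obtained with numpyro.handlers.trace(numpyro.handlers.seed(joint_model, rng_seed=0)).get_trace() and passed to missing_plate_frames. A non-empty result would list the sample sites whose plate is not recorded under the name their frame carries, which would point at whatever wraps CAR_layer_units inside joint_model.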

Cheers