How to access the output of a neural net in the guide?

I’m working to implement a simple amortized inference model for gmms. I’m wanting to record the output of the neural net in my guide, but can’t seem to get it working. Right now I’m just creating a pyro.param(‘phi’, phi) object which I can reference later on using pyro.param(‘phi’).

T = 5  # Fixed number of components.

class MLP(nn.Module):
  '''
    Outputs a probability vector of length T
  '''
  def __init__(self):
    super().__init__()
    hidden_layer_1_size = 2*N
    hidden_layer_2_size = 2*N
    input_size = N
    output_size = T
    self.layers = nn.Sequential(
      nn.Linear(input_size, hidden_layer_1_size),
      nn.Linear(hidden_layer_1_size, hidden_layer_2_size),
      nn.ReLU(),
      nn.Linear(hidden_layer_2_size, output_size),
      nn.Softmax(0)
    )

  def forward(self, x):
    return self.layers(x)


def model(data):
    with pyro.plate('components', T):
        locs = pyro.sample('locs', Normal(0, 1))

    with pyro.plate('data', N):
        # Local variables.
        assignments = pyro.sample('assignments', Categorical(torch.ones(T) / T)) # returns a vector of length T
        obs = pyro.sample('obs', Normal(locs[assignments], 1), obs=data)

def guide(data):
    # amortize using MLP
    pyro.module('mlp', mlp)
    
    # sample mixture components mu
    tau = pyro.param('tau', lambda: Normal(0, 1).sample([T]))
    with pyro.plate('components', T) as i:      
        pyro.sample('locs', Normal(tau[i], 1))
    
    # sample cluster assignments
    phi = mlp(data.double()) # returns a vector of length T
    pyro.param('phi', phi)    
    with pyro.plate("data", N):
        pyro.sample("assignments", Categorical(phi)) # returns a vector of length N

However, this gives me the following error when I try to train. How would I go about tracking phi?

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-bc0042dec5b6> in <module>
      7 start = time.time()
      8 for step in range(n_steps):
----> 9   svi.step(data)
     10   pyro.get_param_store()
     11   if step % 100 == 0:

~\miniconda3\envs\cpsc532w\lib\site-packages\pyro\infer\svi.py in step(self, *args, **kwargs)
    151         # actually perform gradient steps
    152         # torch.optim objects gets instantiated for any params that haven't been seen yet
--> 153         self.optim(params)
    154 
    155         # zero gradients

~\miniconda3\envs\cpsc532w\lib\site-packages\pyro\optim\optim.py in __call__(self, params, *args, **kwargs)
     88             if p not in self.optim_objs:
     89                 # create a single optim object for that param
---> 90                 self.optim_objs[p] = self._get_optim(p)
     91                 # create a gradient clipping function if specified
     92                 self.grad_clip[p] = self._get_grad_clip(p)

~\miniconda3\envs\cpsc532w\lib\site-packages\pyro\optim\optim.py in _get_optim(self, param)
    150 
    151     def _get_optim(self, param: Union[Iterable[Tensor], Iterable[Dict[Any, Any]]]):
--> 152         return self.pt_optim_constructor([param], **self._get_optim_args(param))  # type: ignore
    153 
    154     # helper to fetch the optim args if callable (only used internally)

~\miniconda3\envs\cpsc532w\lib\site-packages\torch\optim\adam.py in __init__(self, params, lr, betas, eps, weight_decay, amsgrad)
     46         defaults = dict(lr=lr, betas=betas, eps=eps,
     47                         weight_decay=weight_decay, amsgrad=amsgrad)
---> 48         super(Adam, self).__init__(params, defaults)
     49 
     50     def __setstate__(self, state):

~\miniconda3\envs\cpsc532w\lib\site-packages\torch\optim\optimizer.py in __init__(self, params, defaults)
     52 
     53         for param_group in param_groups:
---> 54             self.add_param_group(param_group)
     55 
     56     def __getstate__(self):

~\miniconda3\envs\cpsc532w\lib\site-packages\torch\optim\optimizer.py in add_param_group(self, param_group)
    255                                 "but one of the params is " + torch.typename(param))
    256             if not param.is_leaf:
--> 257                 raise ValueError("can't optimize a non-leaf Tensor")
    258 
    259         for name, default in self.defaults.items():

ValueError: can't optimize a non-leaf Tensor