I'm working on implementing a simple amortized inference model for GMMs in Pyro. I want to record the output of the neural net in my guide, but can't seem to get it working. Right now I'm just creating a pyro.param('phi', phi) object, which I can reference later via pyro.param('phi').
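For example, after training I'd expect to read it back out of the param store like this (phi_hat is just an illustrative name):

phi_hat = pyro.param('phi')               # look up by name
phi_hat = pyro.get_param_store()['phi']   # equivalent, via the global store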
import torch
import torch.nn as nn
import pyro
from pyro.distributions import Categorical, Normal

T = 5  # Fixed number of components.
# N = number of data points, defined elsewhere from the dataset.

class MLP(nn.Module):
    '''
    Takes the full dataset (a vector of length N) and
    outputs a probability vector of length T.
    '''
    def __init__(self):
        super().__init__()
        input_size = N
        hidden_layer_1_size = 2 * N
        hidden_layer_2_size = 2 * N
        output_size = T
        self.layers = nn.Sequential(
            nn.Linear(input_size, hidden_layer_1_size),
            nn.Linear(hidden_layer_1_size, hidden_layer_2_size),
            nn.ReLU(),
            nn.Linear(hidden_layer_2_size, output_size),
            nn.Softmax(dim=0),
        )

    def forward(self, x):
        return self.layers(x)

mlp = MLP()
def model(data):
    with pyro.plate('components', T):
        locs = pyro.sample('locs', Normal(0., 1.))
    with pyro.plate('data', N):
        # Local variables: one component assignment per data point.
        assignments = pyro.sample('assignments', Categorical(torch.ones(T) / T))  # vector of length N
        pyro.sample('obs', Normal(locs[assignments], 1.), obs=data)
def guide(data):
    # Amortize inference over the cluster assignments with the MLP.
    pyro.module('mlp', mlp)
    # Sample the mixture component locations mu.
    tau = pyro.param('tau', lambda: Normal(0., 1.).sample([T]))
    with pyro.plate('components', T) as i:
        pyro.sample('locs', Normal(tau[i], 1.))
    # Sample the cluster assignments.
    phi = mlp(data.double())  # probability vector of length T
    pyro.param('phi', phi)    # my attempt at recording phi
    with pyro.plate('data', N):
        pyro.sample('assignments', Categorical(phi))  # vector of length N
However, this gives me the following error when I try to train (my training loop is sketched below). How would I go about tracking phi?
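For context, the loop is roughly this (reconstructed from the traceback; the learning rate and n_steps values are placeholders):

import time
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(model, guide, Adam({'lr': 0.01}), loss=Trace_ELBO())  # lr is a placeholder

n_steps = 1000  # placeholder
start = time.time()
for step in range(n_steps):
    svi.step(data)
    if step % 100 == 0:
        print(f'step {step}, elapsed {time.time() - start:.1f}s')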
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-6-bc0042dec5b6> in <module>
7 start = time.time()
8 for step in range(n_steps):
----> 9 svi.step(data)
10 pyro.get_param_store()
11 if step % 100 == 0:
~\miniconda3\envs\cpsc532w\lib\site-packages\pyro\infer\svi.py in step(self, *args, **kwargs)
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
--> 153 self.optim(params)
154
155 # zero gradients
~\miniconda3\envs\cpsc532w\lib\site-packages\pyro\optim\optim.py in __call__(self, params, *args, **kwargs)
88 if p not in self.optim_objs:
89 # create a single optim object for that param
---> 90 self.optim_objs[p] = self._get_optim(p)
91 # create a gradient clipping function if specified
92 self.grad_clip[p] = self._get_grad_clip(p)
~\miniconda3\envs\cpsc532w\lib\site-packages\pyro\optim\optim.py in _get_optim(self, param)
150
151 def _get_optim(self, param: Union[Iterable[Tensor], Iterable[Dict[Any, Any]]]):
--> 152 return self.pt_optim_constructor([param], **self._get_optim_args(param)) # type: ignore
153
154 # helper to fetch the optim args if callable (only used internally)
~\miniconda3\envs\cpsc532w\lib\site-packages\torch\optim\adam.py in __init__(self, params, lr, betas, eps, weight_decay, amsgrad)
46 defaults = dict(lr=lr, betas=betas, eps=eps,
47 weight_decay=weight_decay, amsgrad=amsgrad)
---> 48 super(Adam, self).__init__(params, defaults)
49
50 def __setstate__(self, state):
~\miniconda3\envs\cpsc532w\lib\site-packages\torch\optim\optimizer.py in __init__(self, params, defaults)
52
53 for param_group in param_groups:
---> 54 self.add_param_group(param_group)
55
56 def __getstate__(self):
~\miniconda3\envs\cpsc532w\lib\site-packages\torch\optim\optimizer.py in add_param_group(self, param_group)
255 "but one of the params is " + torch.typename(param))
256 if not param.is_leaf:
--> 257 raise ValueError("can't optimize a non-leaf Tensor")
258
259 for name, default in self.defaults.items():
ValueError: can't optimize a non-leaf Tensor