I am trying to use a Pyro plate to define a number of conditionally independent latent parameter distributions, but I am having trouble accessing and using these parameters during inference. The model has a set of N “spectral line” components, each of which is represented by a Gaussian with unknown amplitude (mean and variance are known in this simple example). The priors on each of the (log) amplitudes are the same, so to me it makes sense to define these parameters using a plate. I’ve created a simplified minimal reproducible example here, but the relevant part of the model is:
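```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

# gaussian(x, A, mu) is a helper defined in the full example (not shown here).


class Spectrum(PyroModule):
    def __init__(self, mus):
        super().__init__()
        self.mus = torch.as_tensor(mus)  # known line centers (means)
        self.nlines = len(self.mus)
        # One log-amplitude per line, all with the same independent prior
        with pyro.plate("plate", self.nlines):
            self.log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))
        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    def intensities(self, x):
        # Sum of the N Gaussian line profiles evaluated at x
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, self.log_amplitudes[i])
            mu_i = self.mus[i]
            I += gaussian(x, A_i, mu_i)
        return I

    def forward(self, x, y, yerr):
        I = self.intensities(x) + self.baseline
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)
        return I
```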
If I create a guide and/or run `SVI`, none of the `self.log_amplitudes` parameters appear in the guide or in the sample outputs. I’m confused as to what happens to them.
I know that I could define `log_amplitudes` without a plate using the line

```python
self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2).expand([self.nlines]).to_event(1))
```
as described in the tensor shapes example, and this does work in inference loops. However, I’d like to understand the plate feature of Pyro, since it seems central to building larger, more complex models. Most examples I have seen in the Pyro documentation use plates for observed random variables (and I’ve been able to successfully implement those), but I have also seen several examples with discrete latent variables where plates have been used for latent (unobserved) variables.
Can anyone please help me understand how I should use plates with continuous latent variables, and how I can then use them in an optimization/inference loop?
Finally, as a side note, I get a tensor shape error if I do

```python
with pyro.plate("plate", self.nlines):
    self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))
```

that is, using `PyroSample` instead of `pyro.sample(...)` as in my example above.