# Using plates with latent variables

I am trying to use a Pyro plate to define a number of conditionally independent latent parameter distributions, but am having trouble using/accessing these parameters in inference. The model has a set of N “spectral line” components, each of which is represented by a Gaussian with unknown amplitude (mean and variance are known in this simple example). The priors on each of the (log) amplitudes are the same, so to me it makes sense to define these parameters using a plate. I’ve created a simplified minimal reproducible example here, but the relevant part of the model is:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

# `gaussian(x, A, mu)` is the line-profile helper from the full example (not shown).


class Spectrum(PyroModule):
    def __init__(self, mus):
        super().__init__()

        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)

        with pyro.plate("plate", self.nlines):
            self.log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    def intensities(self, x):
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, self.log_amplitudes[i])
            mu_i = self.mus[i]
            I += gaussian(x, A_i, mu_i)

        return I

    def forward(self, x, y, yerr):
        I = self.intensities(x) + self.baseline
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)

        return I
```

If I create a guide and/or run `SVI`, none of the `self.log_amplitudes` parameters appear in the guide or in the sample outputs. I’m confused about what happens to them.

I know that I could define `log_amplitudes` without a plate using the line,

```python
self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2).expand([self.nlines]).to_event(1))
```

as described in the tensor shapes example, and this does work in inference loops. However, I’d like to understand the plate feature of Pyro, since it seems central to building larger, more complex models. Most examples I have seen in the Pyro documentation use plates for observed random variables (and I’ve been able to implement those successfully), but I have also seen several examples with discrete latent variables where plates are used for latent (unobserved) variables.

Can anyone please help me understand how I should use plates with continuous latent variables, and how I can then use them in an optimization/inference loop?

Finally, as a side note, I get a tensor shape error if I do

```python
with pyro.plate("plate", self.nlines):
    self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))
```

where I used `PyroSample` instead of `pyro.sample(...)` as in my example above.

Thank you.

I rewrote this example following the `def model(...)` syntax rather than using `PyroModule`. The full example is here, but the relevant code is now:

```python
import torch
import pyro
import pyro.distributions as dist

# `gaussian(x, A, mu)` and `true_mus` come from the full example (not shown).


def model_func(x, y, yerr):

    baseline = pyro.sample("baseline", dist.Normal(0.0, 1.0))
    with pyro.plate("plate", 3):
        log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

    I = torch.zeros_like(x)
    for i in range(3):
        A_i = torch.pow(10.0, log_amplitudes[i])
        mu_i = true_mus[i]
        I += gaussian(x, A_i, mu_i)

    I += baseline
    with pyro.plate("data", len(y)):
        pyro.sample("obs", dist.Normal(I, yerr), obs=y)
```

This example seems to work as expected and I see the `log_amplitudes` variables show up in the inference routines. Was I misunderstanding something about how plates work in `PyroModule`, or have I potentially discovered a bug?

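Hi. `pyro.sample` works a bit differently from `PyroSample`. `pyro.sample` has to be called from the `forward` method (or from a method that `forward` calls), not from `__init__`: a call in `__init__` runs only once, when the module is constructed and outside any inference trace, so `self.log_amplitudes` ends up as a fixed tensor rather than a latent variable that a guide or `SVI` can see. By contrast, assigning a `PyroSample` to an attribute in `__init__` doesn’t do any sampling at all. The value is drawn only when you access that attribute in `forward` (under the hood, the access calls `pyro.sample`), so whatever plate is active at that moment applies to it. For the same reason, wrapping the `PyroSample` assignment in a plate inside `__init__` has no effect. Hope this clarifies it a little.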

Explained here in more detail.

Thanks! Your comment really pointed me in the right direction toward understanding how plates interact with `PyroModule`s. I revised the relevant code to:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample, pyro_method

# `gaussian(x, A, mu)` is the line-profile helper from the full example (not shown).


class Spectrum(PyroModule):
    def __init__(self, mus):
        super().__init__()

        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)

        # scalar prior; the plate in forward() expands it to shape (nlines,)
        self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))

        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    @pyro_method
    def intensities(self, x):
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, self.log_amplitudes[i])
            mu_i = self.mus[i]
            I += gaussian(x, A_i, mu_i)

        return I

    def forward(self, x, y, yerr):

        with pyro.plate("plate", self.nlines):
            print("log amp values", self.log_amplitudes)
            I = self.intensities(x)

        I += self.baseline

        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)

        return I
```

And things seem to work so far! `self.log_amplitudes` is a length-`self.nlines` vector whenever it is accessed inside the plate. This bit of the documentation also made a lot more sense after your comment.

Thanks for the help!