Using plates with latent variables

I am trying to use a Pyro plate to define a number of conditionally independent latent parameter distributions, but I am having trouble accessing these parameters during inference. The model has a set of N “spectral line” components, each of which is represented by a Gaussian with unknown amplitude (the mean and variance are known in this simple example). The priors on the (log) amplitudes are all the same, so to me it makes sense to define these parameters using a plate. I’ve created a simplified minimal reproducible example here, but the relevant part of the model is:

class Spectrum(PyroModule):
    def __init__(self, mus):
        super().__init__()

        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)
        
        with pyro.plate("plate", self.nlines):
            self.log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    def intensities(self, x):
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, self.log_amplitudes[i])
            mu_i = self.mus[i]
            I += gaussian(x, A_i, mu_i)

        return I

    def forward(self, x, y, yerr):
        I = self.intensities(x) + self.baseline
        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)

        return I

If I create a guide and/or run SVI, none of the self.log_amplitudes parameters appear in the guide or in the sampled outputs. I’m confused about what happens to them.

I know that I could define log_amplitudes without a plate using the line,

self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2).expand([self.nlines]).to_event(1))

as described in the tensor shapes example, and this works in inference loops. However, I’d like to understand Pyro's plate feature, since it seems central to building larger, more complex models. Most examples I have seen in the Pyro documentation use plates for observed random variables (and I’ve been able to implement those successfully), but I have also seen several examples with discrete latent variables where plates are used for latent (unobserved) variables.

Can anyone please help me understand how I should use plates with continuous latent variables, and how then I may use them in an optimization/inference loop?

Finally, as a side note, I get a tensor shape error if I do

with pyro.plate("plate", self.nlines):
    self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))

where I used PyroSample instead of the pyro.sample(...) call from my example above.

Thank you.

I rewrote this example using the def model(...) syntax rather than PyroModule. The full example is here, but the relevant code is now

def model_func(x, y, yerr):

    baseline = pyro.sample("baseline", dist.Normal(0.0, 1.0))
    with pyro.plate("plate", 3):
        log_amplitudes = pyro.sample("log_amplitudes", dist.Normal(1.0, 0.2))

    I = torch.zeros_like(x)
    for i in range(3):
        A_i = torch.pow(10.0, log_amplitudes[i])
        mu_i = true_mus[i]
        I += gaussian(x, A_i, mu_i)

    I += baseline
    with pyro.plate("data", len(y)):
        pyro.sample("obs", dist.Normal(I, yerr), obs=y)

This example seems to work as expected and I see the log_amplitudes variables show up in the inference routines. Was I misunderstanding something about how plates work in PyroModule, or have I potentially discovered a bug?

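Hi. pyro.sample works a bit differently from PyroSample. pyro.sample needs to be called from the .forward method (or from a method that .forward calls), not from __init__. A pyro.sample call in __init__ runs only once, when the module is constructed. The result is just a fixed tensor stored on the module, and it never becomes a sample site that a guide or SVI can see. When you assign PyroSample to an attribute in the __init__ method, on the other hand, no sampling happens at all. The attribute is only sampled when you access it in the .forward method (under the hood this calls pyro.sample), and that access is recorded like any other sample site. Hope this clarifies it a little.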

Explained here in more detail.

Thanks! Your comment really pointed me in the right direction for understanding how plates interact with PyroModules. I revised the relevant code to

class Spectrum(PyroModule):
    def __init__(self, mus):
        super().__init__()

        self.mus = torch.as_tensor(mus)
        self.nlines = len(self.mus)

        # only stores the prior; sampled with shape (nlines,) when accessed inside the plate in forward()
        self.log_amplitudes = PyroSample(dist.Normal(1.0, 0.2))
        
        self.baseline = PyroSample(dist.Normal(0.0, 1.0))

    @pyro_method
    def intensities(self, x):
        I = torch.zeros_like(x)
        for i in range(self.nlines):
            A_i = torch.pow(10.0, self.log_amplitudes[i])
            mu_i = self.mus[i]
            I += gaussian(x, A_i, mu_i)

        return I

    def forward(self, x, y, yerr):

        with pyro.plate("plate", self.nlines):
            print("log amp values", self.log_amplitudes)
            I = self.intensities(x)

        I += self.baseline

        with pyro.plate("data", len(y)):
            pyro.sample("obs", dist.Normal(I, yerr), obs=y)

        return I

And things seem to work so far! self.log_amplitudes is now a vector of length self.nlines whenever it is accessed inside the plate. This bit of the documentation also made a lot more sense after your comment.

Thanks for the help!