Mixture density network - Enumeration with TraceEnum_ELBO

I am trying to implement a Mixture Density Network (MDN) with Pyro, based on this PyTorch implementation: pytorch_notebooks/mixture_density_networks.ipynb at master · hardmaru/pytorch_notebooks · GitHub

I have implemented my model and guide, and I am trying to enumerate over the discrete mixture assignments. I read the tutorial on enumeration (Inference with Discrete Latent Variables — Pyro Tutorials 1.8.6 documentation), the one on tensor shapes (Tensor shapes in Pyro — Pyro Tutorials 1.8.6 documentation), and the GMM tutorial (Gaussian Mixture Model — Pyro Tutorials 1.8.6 documentation), but I can’t figure it out.

The difference from the GMM tutorial and from the example in the enumeration tutorial is that, since I have an MDN, the parameters of the Gaussian mixture are different for each data sample. So, another way to formulate my question: how do I extend the simple code snippet in the enumeration tutorial (under “Plates and enumeration”) to have a different mean for each data sample?
My code is posted at the end. When I run it, I get the following error:

ValueError: Error while packing tensors at site 'assignment':
  Invalid tensor shape.
  Allowed dims: -2
  Actual shape: (5, 10)
  Try adding shape assertions for your model's sample values and distribution parameters.
       Trace Shapes:        
        Param Sites:        
  MDN$$$z_h.0.weight   20  1
    MDN$$$z_h.0.bias      20
   MDN$$$z_pi.weight    5 20
     MDN$$$z_pi.bias       5
MDN$$$z_sigma.weight    5 20
  MDN$$$z_sigma.bias       5
   MDN$$$z_mu.weight    5 20
     MDN$$$z_mu.bias       5
       Sample Sites:        
     assignment dist   10  |
               value 5  1  |
            log_prob 5 10  |

When I run the following code with my defined model:

trace = poutine.trace(
    poutine.enum(network.model, first_available_dim=-2)
).get_trace(x_variable, y_variable)
trace.compute_log_prob() # optional, but allows printing of log_prob shapes
print(trace.format_shapes())

I get the following output:

  Trace Shapes:        
        Param Sites:        
  MDN$$$z_h.0.weight   20  1
    MDN$$$z_h.0.bias      20
   MDN$$$z_pi.weight    5 20
     MDN$$$z_pi.bias       5
MDN$$$z_sigma.weight    5 20
  MDN$$$z_sigma.bias       5
   MDN$$$z_mu.weight    5 20
     MDN$$$z_mu.bias       5
       Sample Sites:        
     assignment dist   10  |
               value 5  1  |
            log_prob 5 10  |
            obs dist 5 10  |
               value 1 10  |
            log_prob 5 10  |

This seems fine to me. Since this runs without error, I assume the problem is in my guide.
What am I doing wrong here?

Here is my code:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import pyro
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoDelta
from pyro import poutine
from pyro.infer import SVI
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam

torch.manual_seed(123)
np.random.seed(4530)

def generate_data(n_samples):
    epsilon = np.random.normal(size=(n_samples))
    x_data = np.random.uniform(-10.5, 10.5, n_samples)
    y_data = 7*np.sin(0.75*x_data) + 0.5*x_data + epsilon
    return x_data, y_data

n_samples = 10
x_data, y_data = generate_data(n_samples)

x_tensor = torch.from_numpy(np.float32(x_data).reshape(n_samples))
y_tensor = torch.from_numpy(np.float32(y_data).reshape(n_samples))
x_variable = Variable(x_tensor)
y_variable = Variable(y_tensor, requires_grad=False)

class MDN(nn.Module):
    def __init__(self, n_hidden, n_gaussians):
        super().__init__()
        self.z_h = nn.Sequential(
            nn.Linear(1, n_hidden),
            nn.Tanh()
        )
        self.z_pi = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)

    def forward(self, data):
        # All three outputs have shape (n_samples, n_gaussians).
        z_h = self.z_h(data.view(-1, 1))
        pi = nn.functional.softmax(self.z_pi(z_h), -1)
        sigma = torch.exp(self.z_sigma(z_h))
        mu = self.z_mu(z_h)
        return pi, sigma, mu

    @config_enumerate
    def model(self, x, y):
        pyro.module("MDN", self)
        pi, sigma, mu = self.forward(y)
        # Transpose to (n_gaussians, n_samples) so the assignment indexes dim 0.
        muT = torch.transpose(mu, 0, 1)
        sigmaT = torch.transpose(sigma, 0, 1)
        assignment = pyro.sample("assignment", dist.Categorical(pi))
        pyro.sample('obs', dist.Normal(muT[assignment][:, 0],
                                       sigmaT[assignment][:, 0]),
                    obs=x[None, :])

network = MDN(n_hidden=20, n_gaussians=5)
adam_params = {"lr": 0.001, "betas": (0.9, 0.999)}
optimizer = Adam(adam_params)

y_variable = Variable(y_tensor)
x_variable = Variable(x_tensor, requires_grad=False)

guide = AutoDelta(poutine.block(network.model, hide=['assignment', 'obs']))
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(network.model, guide, optimizer, loss=elbo)

def train_mdn():
    for epoch in range(1000):
        loss = svi.step(x_variable, y_variable)
        if epoch % 500 == 0:
            print(epoch, loss)

train_mdn()

Well first I do recommend

Try adding shape assertions for your model’s sample values and distribution parameters.

That will help you and help readers like me understand your code :wink: From the error message

Invalid tensor shape.
Allowed dims: -2
Actual shape: (5, 10)

it looks like Pyro knows about dimension -2 of size 5, but doesn’t know about your dimension -1 of size 10 (note that we always count dimensions from the right so as to allow broadcasting). To fix, I think you simply need to add a pyro.plate to tell Pyro that you are vectorizing over samples, something like

@config_enumerate
def model(self, x, y):
    pyro.module("MDN", self)

    pi, sigma, mu = self.forward(y)
    muT = torch.transpose(mu, 0, 1)
    sigmaT = torch.transpose(sigma, 0, 1)
    assert muT.shape == (n_gaussians, 1)  # is this right?
    assert sigmaT.shape == (n_gaussians, 1)  # is this right?

    with pyro.plate("samples", n_samples):
        assignment = pyro.sample("assignment", dist.Categorical(pi))
        pyro.sample('obs', dist.Normal(muT[assignment][:, 0],
                                       sigmaT[assignment][:, 0]),
                    obs=x[None, :])

Let me know how this works!
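
To make the right-aligned dimension counting concrete, here is a small sketch that uses NumPy as a stand-in for torch (an assumption made only so that it runs without Pyro). The sizes 5 and 10 are taken from the error message above; the variable names are made up for illustration.

```python
import numpy as np

n_gaussians, n_samples = 5, 10

# Enumerated assignments: one row per component, plus a size-1 dim on the
# right that is reserved for the plate (compare "value 5 1" in the trace).
enumerated = np.arange(n_gaussians).reshape(n_gaussians, 1)

# One observation per data sample; this is the dim the plate has to declare.
obs = np.zeros(n_samples)

# Shapes are aligned from the right, so (5, 1) and (10,) broadcast to (5, 10).
joint = enumerated + obs
print(joint.shape)  # (5, 10)
```

Read from the right, dim -1 (size 10) is the sample dim and dim -2 (size 5) is the enumeration dim. That matches the "Allowed dims: -2" message: without a plate, only the enumeration dim has been declared.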

Thank you for your fast reply!
I inserted the pyro.plate and now the code runs (I had tried variants of plate but don’t remember which exactly).

However, I need to make a distinction between two cases:

  1. The case where we actually sample the assignments and
  2. the case where we enumerate assignments.

This is like in the tutorial (“Plates and enumeration”, Inference with Discrete Latent Variables — Pyro Tutorials 1.8.4 documentation). Btw, I don’t fully understand why the first run occurs (“Observe that the model is run twice, first by the AutoDiagonalNormal to trace sample sites…”)

So, my code looks like this:

@config_enumerate
def model(self, x, y):
    pyro.module("MDN", self)

    pi, sigma, mu = self.forward(y)
    muT = torch.transpose(mu, 0, 1)
    sigmaT = torch.transpose(sigma, 0, 1)

    assert muT.shape == (n_gaussians, n_samples)
    assert sigmaT.shape == (n_gaussians, n_samples)
    with pyro.plate("sample", n_samples):
        assignment = pyro.sample("assignment", dist.Categorical(pi))
        if len(assignment.shape) == 1:
            pyro.sample('obs', dist.Normal(torch.gather(muT, 0, assignment.view(1, -1))[0],
                                         torch.gather(sigmaT, 0, assignment.view(1, -1))[0]),
                        obs=x)
        else:
            pyro.sample('obs', dist.Normal(muT[assignment][:, 0],
                                           sigmaT[assignment][:, 0]),
                        obs=x)

A related question: to compute the log_prob of an observation, each of the Gaussian components that I am enumerating should be weighted by its mixture weight, I assume. But how does Pyro achieve this? For the log_prob of a Gaussian mixture, each component has to be weighted first and then logarithmized. Is this what’s happening under the hood?

I need to make a distinction between two cases

Hmm I’m sure there’s a way to use advanced indexing to do that without a conditional, but advanced indexing is so unintuitive that I can’t help you figure out that way :weary:
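
One possible way to drop the conditional is sketched below with NumPy rather than torch, so that it runs without Pyro. Torch tensors should follow the same broadcasting rules for integer-array indexing, but treat that as an assumption to verify. The idea is to index the untransposed parameters of shape (n_samples, n_gaussians) with a pair of index arrays that broadcast against each other. The helper name select_component and the toy values are hypothetical.

```python
import numpy as np

def select_component(params, assignment):
    """Pick params[i, assignment[..., i]] for every sample i.

    params:     (n_samples, n_gaussians), one row of component parameters per sample
    assignment: (n_samples,) when sampled, or (n_gaussians, 1) when enumerated
    """
    sample_idx = np.arange(params.shape[0])
    # The two index arrays broadcast against each other, so this one line
    # returns shape (n_samples,) or (n_gaussians, n_samples) as needed.
    return params[sample_idx, assignment]

n_samples, n_gaussians = 10, 5
# Toy parameters: mu[i, k] = 5 * i + k, so results are easy to check by eye.
mu = np.arange(n_samples * n_gaussians, dtype=float).reshape(n_samples, n_gaussians)

sampled = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4])   # one component per sample
enumerated = np.arange(n_gaussians).reshape(-1, 1)    # every component, shape (5, 1)

mu_sampled = select_component(mu, sampled)
mu_enumerated = select_component(mu, enumerated)
print(mu_sampled.shape, mu_enumerated.shape)
```

In the enumerated case the result is simply the transpose of mu, which is the (n_gaussians, n_samples) layout that the obs site expects. Pyro also provides pyro.ops.indexing.Vindex for broadcast-aware indexing of this kind, which may be worth checking before hand-rolling the index arrays.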

Btw, I don’t fully understand why the first run occurs (“Observe that the model is run twice, first by the AutoDiagonalNormal to trace sample sites…”)

This is a little tricky. AutoDiagonalNormal needs to inspect your model to know which variables (and which shapes and constraints) it should model. In theory we could do static code analysis, but to be maximally flexible AutoDiagonalNormal instead simply runs the model and records variables, shapes, and constraints. But AutoDiagonalNormal doesn’t enumerate anything: it simply ignores the infer={"enumerate": "parallel"} annotations created by @config_enumerate.

…weighted first and then logarithmized. Is this what’s happening under the hood?

Can you clarify your question? I think you’re using “sample” in two different senses. You might take a look at the unit tests to convince yourself Pyro is doing what you intend.

Thanks for the clarification of my first question. The documentation of iter_discrete_traces answers my second question.
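
For readers with the same question, here is a minimal sketch of the weighting in plain Python. It is not Pyro's internal code, and the function names and toy numbers are made up for illustration. The point is that summing over assignments can be done in log space: the weighted sum of component densities becomes a log-sum-exp over log pi_k + log N(x | mu_k, sigma_k), so nothing has to be exponentiated naively.

```python
import math

def normal_log_pdf(x, mu, sigma):
    """Log-density of a univariate Normal(mu, sigma) at x."""
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2

def mixture_log_prob(x, pis, mus, sigmas):
    """log sum_k pi_k * N(x | mu_k, sigma_k), computed in log space."""
    # Joint log-probability of (assignment = k, x) for every component k.
    joint = [math.log(p) + normal_log_pdf(x, m, s)
             for p, m, s in zip(pis, mus, sigmas)]
    # Log-sum-exp over components, shifted by the maximum for stability.
    top = max(joint)
    return top + math.log(sum(math.exp(j - top) for j in joint))

pis = [0.2, 0.5, 0.3]
mus = [-1.0, 0.0, 2.0]
sigmas = [0.5, 1.0, 1.5]
x = 0.3

# Reference value: weight the densities first, then take the logarithm.
direct = math.log(sum(p * math.exp(normal_log_pdf(x, m, s))
                      for p, m, s in zip(pis, mus, sigmas)))
log_space = mixture_log_prob(x, pis, mus, sigmas)
print(math.isclose(direct, log_space, rel_tol=1e-9))
```

Both routes give the same number; the log-space version is the one that stays stable when individual densities underflow.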

I agree that the case distinction is a little ugly. I have been trying different things for the indexing, but couldn’t figure out a smarter way.

Thank you for your help, my code runs fine now.

I have now published my implementation on GitHub: GitHub - ascentai/mdn_pyro: Implementation of a Mixture Density Network in the deep probabilistic programming language Pyro.
