I am trying to implement a Mixture Density Network (MDN) in Pyro, based on this PyTorch implementation: pytorch_notebooks/mixture_density_networks.ipynb (master branch of hardmaru/pytorch_notebooks on GitHub).
I have implemented my model and guide, and I am trying to enumerate over the discrete mixture assignments. I read the Pyro tutorials on enumeration (Inference with Discrete Latent Variables) and on tensor shapes (Tensor shapes in Pyro), as well as the GMM tutorial (Gaussian Mixture Model), all in the Pyro Tutorials 1.8.6 documentation, but I can't figure it out.
The difference from the GMM tutorial and from the example in the enumeration tutorial is that, since this is an MDN, the parameters of the Gaussian mixture are different for each data sample. So, another way to phrase my question: how do I extend the simple code snippet in the enumeration tutorial (under "Plates and enumeration") so that each data sample has a different mean?
My full code is at the end of this post. When I run it, I get the following error:
ValueError: Error while packing tensors at site 'assignment':
Invalid tensor shape.
Allowed dims: -2
Actual shape: (5, 10)
Try adding shape assertions for your model's sample values and distribution parameters.
Trace Shapes:
Param Sites:
MDN$$$z_h.0.weight 20 1
MDN$$$z_h.0.bias 20
MDN$$$z_pi.weight 5 20
MDN$$$z_pi.bias 5
MDN$$$z_sigma.weight 5 20
MDN$$$z_sigma.bias 5
MDN$$$z_mu.weight 5 20
MDN$$$z_mu.bias 5
Sample Sites:
assignment dist 10 |
value 5 1 |
log_prob 5 10 |
When I trace my model directly with the following code:
trace = poutine.trace(poutine.enum(network.model,
                                   first_available_dim=-2)).get_trace(x_variable, y_variable)
trace.compute_log_prob() # optional, but allows printing of log_prob shapes
print(trace.format_shapes())
I get the following output:
Trace Shapes:
Param Sites:
MDN$$$z_h.0.weight 20 1
MDN$$$z_h.0.bias 20
MDN$$$z_pi.weight 5 20
MDN$$$z_pi.bias 5
MDN$$$z_sigma.weight 5 20
MDN$$$z_sigma.bias 5
MDN$$$z_mu.weight 5 20
MDN$$$z_mu.bias 5
Sample Sites:
assignment dist 10 |
value 5 1 |
log_prob 5 10 |
obs dist 5 10 |
value 1 10 |
log_prob 5 10 |
This looks fine to me. Since it finishes without error, I assume the problem is in my guide?
What am I doing wrong here?
Here is my code:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import pyro
import pyro.distributions as dist
from pyro.infer.autoguide import AutoDelta  # formerly pyro.contrib.autoguide
from pyro import poutine
from pyro.infer import SVI
from pyro.infer import TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam
torch.manual_seed(123)
np.random.seed(4530)
def generate_data(n_samples):
    epsilon = np.random.normal(size=(n_samples))
    x_data = np.random.uniform(-10.5, 10.5, n_samples)
    y_data = 7 * np.sin(0.75 * x_data) + 0.5 * x_data + epsilon
    return x_data, y_data
n_samples = 10
x_data, y_data = generate_data(n_samples)
x_tensor = torch.from_numpy(np.float32(x_data).reshape(n_samples))
y_tensor = torch.from_numpy(np.float32(y_data).reshape(n_samples))
x_variable = Variable(x_tensor)
y_variable = Variable(y_tensor, requires_grad=False)
class MDN(nn.Module):
    def __init__(self, n_hidden, n_gaussians):
        super().__init__()
        self.z_h = nn.Sequential(
            nn.Linear(1, n_hidden),
            nn.Tanh()
        )
        self.z_pi = nn.Linear(n_hidden, n_gaussians)
        self.z_sigma = nn.Linear(n_hidden, n_gaussians)
        self.z_mu = nn.Linear(n_hidden, n_gaussians)

    def forward(self, data):
        z_h = self.z_h(data.view(-1, 1))
        pi = nn.functional.softmax(self.z_pi(z_h), -1)  # mixture weights per sample
        sigma = torch.exp(self.z_sigma(z_h))  # positive standard deviations
        mu = self.z_mu(z_h)
        return pi, sigma, mu  # each of shape (n_samples, n_gaussians)

    @config_enumerate
    def model(self, x, y):
        pyro.module("MDN", self)
        pi, sigma, mu = self.forward(y)
        muT = torch.transpose(mu, 0, 1)  # shape (n_gaussians, n_samples)
        sigmaT = torch.transpose(sigma, 0, 1)
        assignment = pyro.sample("assignment", dist.Categorical(pi))
        pyro.sample('obs', dist.Normal(muT[assignment][:, 0], sigmaT[assignment][:, 0]), obs=x[None, :])
network = MDN(n_hidden=20, n_gaussians=5)
adam_params = {"lr": 0.001, "betas": (0.9, 0.999)}
optimizer = Adam(adam_params)
y_variable = Variable(y_tensor)
x_variable = Variable(x_tensor, requires_grad=False)
guide = AutoDelta(poutine.block(network.model, hide=['assignment', 'obs']))
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(network.model, guide, optimizer, loss=elbo)
def train_mdn():
    for epoch in range(1000):
        loss = svi.step(x_variable, y_variable)
        if epoch % 500 == 0:
            print(epoch, loss)
train_mdn()