Vectorized Predictive distribution incorrectly computes site shape


I am using Pyro version 1.4.0. I am modelling a time series, and I use SVI to also learn an inference network that outputs the distribution of the latent variables at each time step (given only the input at that specific time step).

def guide(self, x):
    pyro.module('latent_dependency_vae', self)
    b_s = x.shape[1]
    with pyro.plate('batch_plate', b_s, dim=-1):
        mu_z, sigma_z = self.inference_net.x_to_z(x)
        z = torch.empty(self.T, self.num_samples_elbo, b_s, self.dim_z)
        for t in range(self.T):
            z[t] = pyro.sample('z_{}'.format(t), dist.Normal(mu_z[t], sigma_z[t]).to_event(1))
        mu_w, sigma_w = self.inference_net.z_to_w(z)
        for t in range(self.T):
            w = pyro.sample('w_{}'.format(t), dist.Normal(mu_w[t], sigma_w[t]).to_event(1))

In my model, I have a simple Markovian prior on the latent variables:

def prior(self, b_s):
    z = torch.empty((self.T, self.num_samples_elbo, b_s, self.dim_z))
    w = torch.empty((self.T, self.num_samples_elbo, b_s, self.dim_w))
    z[0] = pyro.sample('z_0', dist.Normal(torch.zeros(self.dim_z), 1e-3).to_event(1))
    w[0] = pyro.sample('w_0', dist.Normal(torch.zeros(self.dim_w), 1e-3).to_event(1))
    for t in range(1, self.T):
        z[t] = pyro.sample('z_{}'.format(t), dist.Normal(z[t-1], 1e-3).to_event(1))
        w[t] = pyro.sample('w_{}'.format(t), dist.Normal(w[t-1], 1e-3).to_event(1))
    return z, w

And these are used to generate the emission:

def model(self, x):
    pyro.module('latent_dependency_vae', self)
    b_s = x.shape[1]
    with pyro.plate('batch_plate', b_s, dim=-1):
        z, w = self.prior(b_s)
        mu_x = self.inference_net.latents_to_x(z, w)
        for t in range(self.T):
            x_gen = pyro.sample('x_{}'.format(t), dist.Normal(mu_x[t], 1e-3).to_event(1), obs=x[t])

If I then try to use the Predictive class like so:

import torch
from pyro import poutine
from pyro.infer import Predictive

dim_x = 10
dim_z = 10
dim_w = 10
T = 5
num_samples = 10
vae_model = LatentDependencyVAE(dim_x, dim_z, dim_w, T, num_samples)
predictive = Predictive(poutine.uncondition(vae_model.model), num_samples=num_samples, parallel=True)
minibatch_size = 1
x = torch.randn(T, minibatch_size, dim_x)
samples = predictive(x)

where LatentDependencyVAE is the wrapper containing prior, model and guide:

class LatentDependencyVAE(torch.nn.Module):
    def __init__(self, dim_x, dim_z, dim_w, T, num_samples_elbo):
        super().__init__()  # must run before assigning submodules such as inference_net
        self.dim_z = dim_z
        self.dim_w = dim_w
        self.T = T
        self.num_samples_elbo = num_samples_elbo
        self.inference_net = InferenceNet(dim_x, dim_z, dim_w)

I run into the problem that the site-shape computation in _predictive() (line 70) takes w to have a batch_shape of [10, 1], when it is actually [1]; the [10] is the number of samples of w, which belongs under the vectorized plate.
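To see where the [10, 1] comes from, here is a minimal sketch using plain torch.distributions, where Independent(..., 1) plays the role of Pyro's .to_event(1). It assumes num_samples_elbo = 10, b_s = 1 and dim_w = 10 as in the example above, and it leaves out the batch plate. Because prior() preallocates w with an explicit num_samples_elbo dimension, every w_t with t >= 1 inherits that dimension from w[t-1] as part of its batch shape.

```python
import torch
from torch.distributions import Independent, Normal

T, num_samples_elbo, b_s, dim_w = 5, 10, 1, 10

# Mirrors prior(): w is preallocated with an explicit num_samples_elbo dimension.
w = torch.empty(T, num_samples_elbo, b_s, dim_w)

# w_0: the mean has no batch dimensions of its own
# (inside batch_plate, Pyro would broadcast its batch shape to [b_s]).
d0 = Independent(Normal(torch.zeros(dim_w), 1e-3), 1)
w[0] = d0.sample()  # broadcast into w[0], shape (num_samples_elbo, b_s, dim_w)

# w_1: the mean is w[0], so the explicit sample dimension leaks into batch_shape.
d1 = Independent(Normal(w[0], 1e-3), 1)

print(d0.batch_shape)  # torch.Size([])
print(d1.batch_shape)  # torch.Size([10, 1])
```

Predictive then adds its own sample dimension on top of this [10, 1] batch shape, so the site shapes it computes no longer match the values it draws.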

Clearly this is a flaw in how I have written the model and the guide to explicitly include num_samples, but I am not sure how to write it otherwise. Sorry for the long read and thank you for any help!

Hi @MontyPython, two rules of thumb will make your life easier when using Pyro: (1) always index tensors starting from the right using ellipses (e.g. x[..., t]), and assume the dimensions to the left of your leftmost plate may be occupied; and (2) use plates or the num_particles argument to the ELBO to draw multiple samples from your model, and treat the resulting batch dimensions as a Pyro implementation detail. You can read more about this in the tensor shape tutorial.
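A minimal illustration of rule (1) in plain torch, with hypothetical sizes: indexing a time-last tensor with [..., t] picks the same time step whether or not Pyro has prepended a sample dimension, while positional indexing like [t] silently selects along a different axis once such a dimension appears.

```python
import torch

b_s, dim_z, T, num_samples = 3, 4, 5, 10

# Time is the rightmost dimension; anything left of b_s belongs to Pyro.
z_single = torch.randn(b_s, dim_z, T)                   # no extra sample dims
z_vectorized = torch.randn(num_samples, b_s, dim_z, T)  # e.g. under a vectorized plate

t = 2
# Indexing from the right works for both layouts.
print(z_single[..., t].shape)      # torch.Size([3, 4])
print(z_vectorized[..., t].shape)  # torch.Size([10, 3, 4])

# Indexing from the left does not: here [t] selects along the sample axis.
print(z_vectorized[t].shape)       # torch.Size([3, 4, 5])
```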

To illustrate, I modified your prior so that it no longer makes an explicit assumption about the batch shape of its sample sites. You’ll need to make similar changes to the rest of your code.

def prior(self):  # b_s is not used here now
    zs, ws = [None] * self.T, [None] * self.T
    zs[0] = pyro.sample('z_0', dist.Normal(torch.zeros(self.dim_z), 1e-3).to_event(1))
    ws[0] = pyro.sample('w_0', dist.Normal(torch.zeros(self.dim_w), 1e-3).to_event(1))          
    for t in range(1, self.T):   
        zs[t] = pyro.sample('z_{}'.format(t), dist.Normal(zs[t-1], 1e-3).to_event(1))
        ws[t] = pyro.sample('w_{}'.format(t), dist.Normal(ws[t-1], 1e-3).to_event(1))
    # output two tensors of shape (..., b_s, self.dim_z/w, self.T); the leading ... holds
    # whatever sample dimensions Pyro adds, e.g. self.num_samples_elbo vectorized particles
    return torch.stack(zs, -1), torch.stack(ws, -1)

# in the model, you'll need to modify self.inference_net.latents_to_x
# to expect the time dimension to be the rightmost dimension of z and w
x_gen = pyro.sample('x_{}'.format(t), dist.Normal(mu_x[..., t], 1e-3).to_event(1), obs=x[t])
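One way latents_to_x could be adapted is sketched below. The thread does not show InferenceNet, so LatentsToX (a single torch.nn.Linear over the concatenated latents) is a hypothetical stand-in. The point is only that the network treats time as the rightmost dimension and ignores any leading dimensions, so mu_x[..., t] has shape (..., b_s, dim_x) and broadcasts against x[t] of shape (b_s, dim_x).

```python
import torch

class LatentsToX(torch.nn.Module):
    """Hypothetical emission net: maps (z, w) with time last to mu_x with time last."""

    def __init__(self, dim_z, dim_w, dim_x):
        super().__init__()
        self.linear = torch.nn.Linear(dim_z + dim_w, dim_x)

    def forward(self, z, w):
        # z: (..., b_s, dim_z, T), w: (..., b_s, dim_w, T)
        zw = torch.cat([z, w], dim=-2)          # (..., b_s, dim_z + dim_w, T)
        mu = self.linear(zw.transpose(-1, -2))  # Linear acts on the last dim: (..., b_s, T, dim_x)
        return mu.transpose(-1, -2)             # (..., b_s, dim_x, T)

dim_x, dim_z, dim_w, T, b_s = 10, 10, 10, 5, 1
net = LatentsToX(dim_z, dim_w, dim_x)
x = torch.randn(T, b_s, dim_x)

# Works with or without a leading sample dimension added by Pyro.
for lead in [(), (7,)]:
    z = torch.randn(*lead, b_s, dim_z, T)
    w = torch.randn(*lead, b_s, dim_w, T)
    mu_x = net(z, w)
    print(mu_x.shape, mu_x[..., 0].shape, x[0].shape)
```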

That worked, thank you!