ValueError: at site "N_X", invalid log_prob shape Expected [], actual [4096]

I am working on Dsprites Dataset and using SVI inference to do counterfactual evaluations. Importance algorithm and MCMC didn’t work for us and so we decided to make this an optimization problem by using SVI instead. Here is how my VAE looks like:

class Encoder(nn.Module):
    def __init__(self, image_dim, label_dim, z_dim):
        super(Encoder, self).__init__()
        self.image_dim = image_dim
        self.label_dim = label_dim
        self.z_dim = z_dim
        # setup the three linear transformations used
        self.fc1 = nn.Linear(self.image_dim+self.label_dim, 1000)
        self.fc2 = nn.Linear(1000, 1000)
        self.fc31 = nn.Linear(1000, z_dim)  # mu values
        self.fc32 = nn.Linear(1000, z_dim)  # sigma values
        # setup the non-linearities
        self.softplus = nn.Softplus()

    def forward(self, xs, ys):
        # define the forward computation on the image xs and label ys
        # first shape the mini-batch to have pixels in the rightmost dimension
        xs = xs.reshape(-1, self.image_dim)
        #now concatenate the image and label
        inputs = torch.cat((xs,ys), -1)
        # then compute the hidden units
        hidden1 = self.softplus(self.fc1(inputs))
        hidden2 = self.softplus(self.fc2(hidden1))
        # then return a mean vector and a (positive) square root covariance
        # each of size batch_size x z_dim
        z_loc = self.fc31(hidden2)
        z_scale = torch.exp(self.fc32(hidden2))
        return z_loc, z_scale
    
class Decoder(nn.Module):
    def __init__(self, image_dim, label_dim, z_dim):
        super(Decoder, self).__init__()
        # setup the two linear transformations used
        hidden_dim = 1000
        self.fc1 = nn.Linear(z_dim+label_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, image_dim)
        # setup the non-linearities
        self.softplus = nn.Softplus()
        self.sigmoid = nn.Sigmoid()

    def forward(self, zs, ys):
        # define the forward computation on the latent z and label y
        # first concatenate z and y
        inputs = torch.cat((zs, ys),-1)
        # then compute the hidden units
        hidden1 = self.softplus(self.fc1(inputs))
        hidden2 = self.softplus(self.fc2(hidden1))
        hidden3 = self.softplus(self.fc3(hidden2))
        # return the parameter for the output Bernoulli
        # each is of size batch_size x 784
        loc_img = self.sigmoid(self.fc4(hidden3))
        return loc_img

class CVAE(nn.Module):

    def __init__(self, config_enum=None, use_cuda=False, aux_loss_multiplier=None):

        super(CVAE, self).__init__()
    
        self.image_dim = 64**2
        self.label_shape = np.array((1,3,6,40,32,32))
        self.label_names = np.array(('color', 'shape', 'scale', 'orientation', 'posX', 'posY'))
        self.label_dim = np.sum(self.label_shape)
        self.z_dim = 50                                    
        self.use_cuda = use_cuda

        # define and instantiate the neural networks representing
        # the paramters of various distributions in the model
        self.setup_networks()

    def setup_networks(self):
        self.encoder = Encoder(self.image_dim, self.label_dim, self.z_dim)

        self.decoder = Decoder(self.image_dim, self.label_dim, self.z_dim)

        # using GPUs for faster training of the networks
        if self.use_cuda:
            self.cuda()

    def model(self, xs, ys):
        """
        The model corresponds to the following generative process:
        p(z) = normal(0,I)              # dsprites label (latent)
        p(x|y,z) = bernoulli(loc(y,z))   # an image
        loc is given by a neural network  `decoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :param ys: a batch of the class labels i.e.
                   the digit corresponding to the image(s)
        :return: None
        """
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("cvae", self)

        batch_size = xs.size(0)
        options = dict(dtype=xs.dtype, device=xs.device)
        with pyro.plate("data"):

            prior_loc = torch.zeros(batch_size, self.z_dim, **options)
            prior_scale = torch.ones(batch_size, self.z_dim, **options)
            zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))
            
            # if the label y (which digit to write) is supervised, sample from the
            # constant prior, otherwise, observe the value (i.e. score it against the constant prior)
    
            loc = self.decoder.forward(zs, self.remap_y(ys))
            pyro.sample("x", dist.Bernoulli(loc).to_event(1), obs=xs)
            # return the loc so we can visualize it later
            return loc

    def guide(self, xs, ys):
        """
        The guide corresponds to the following:
        q(z|x,y) = normal(loc(x,y),scale(x,y))       # infer latent class from an image and the label 
        loc, scale are given by a neural network `encoder`

        :param xs: a batch of scaled vectors of pixels from an image
        :return: None
        """
        # inform Pyro that the variables in the batch of xs are conditionally independent
        with pyro.plate("data"):
            # sample (and score) the latent handwriting-style with the variational
            # distribution q(z|x) = normal(loc(x),scale(x))
    
            loc, scale = self.encoder.forward(xs, self.remap_y(ys))
            pyro.sample("z", dist.Normal(loc, scale).to_event(1))
            
    def remap_y(self, ys):
        new_ys = []
        options = dict(dtype=ys.dtype, device=ys.device)
        for i, label_length in enumerate(self.label_shape):
            prior = torch.ones(ys.size(0), label_length, **options) / (1.0 * label_length)
            new_ys.append(pyro.sample("y_%s" % self.label_names[i], dist.OneHotCategorical(prior), 
                                   obs=torch.nn.functional.one_hot(ys[:,i].to(torch.int64), int(label_length))))
        new_ys = torch.cat(new_ys, -1)
        return new_ys.to(torch.float32)
            
    def reconstruct_image(self, xs, ys):
        # backward
        sim_z_loc, sim_z_scale = self.encoder.forward(xs, self.remap_y(ys))
        zs = dist.Normal(sim_z_loc, sim_z_scale).to_event(1).sample()
        # forward
        loc = self.decoder.forward(zs, self.remap_y(ys))
        return dist.Bernoulli(loc).to_event(1).sample()

This is how my Structural Causal Model looks like:

class SCM():
    def __init__(self, vae, mu, sigma):
        self.vae = vae
        self.image_dim = vae.image_dim
        self.z_dim = vae.z_dim
        
        mu = mu.cpu()
        sigma = sigma.cpu()
        
        # these are used for f_X
        self.label_dims = vae.label_shape
        
        def f_X(Y, Z, N):
            zs = Z.cuda()
            
            # convert the labels to one hot
            ys = [torch.tensor([0])]
            ys.append(torch.nn.functional.one_hot(torch.tensor(Y[0]), int(self.label_dims[1])))
            ys.append(torch.nn.functional.one_hot(torch.tensor(Y[1]), int(self.label_dims[2])))
            ys.append(torch.nn.functional.one_hot(torch.tensor(Y[2]), int(self.label_dims[3])))
            ys.append(torch.nn.functional.one_hot(torch.tensor(Y[3]), int(self.label_dims[4])))
            ys.append(torch.nn.functional.one_hot(torch.tensor(Y[4]), int(self.label_dims[5])))
            ys = torch.cat(ys).to(torch.float32).reshape(1,-1).cuda()
            
            p = vae.decoder.forward(zs, ys)
            return (N < p.cpu()).type(torch.float)
        
        def f_Y(N):
            m = torch.distributions.gumbel.Gumbel(torch.zeros(N.size(0)), torch.ones(N.size(0)))
            return torch.argmax(torch.add(torch.log(N), m.sample())).item()
        
        def f_Z(N):
            return N * sigma + mu
        
        def model(noise):
            N_X = pyro.sample( 'N_X', noise['N_X'] )
            # There are 5 Y variables and they will be
            # denoted using the index in the sequence 
            # that they are stored in as vae.label_names:
            # ['shape', 'scale', 'orientation', 'posX', 'posY']
            N_Y_1 = pyro.sample( 'N_Y_1', noise['N_Y_1'] )
            N_Y_2 = pyro.sample( 'N_Y_2', noise['N_Y_2'] )
            N_Y_3 = pyro.sample( 'N_Y_3', noise['N_Y_3'] )
            N_Y_4 = pyro.sample( 'N_Y_4', noise['N_Y_4'] )
            N_Y_5 = pyro.sample( 'N_Y_5', noise['N_Y_5'] ) 
            N_Z = pyro.sample( 'N_Z', noise['N_Z'] )
             
            Z = pyro.sample('Z', dist.Normal( f_Z( N_Z ), 1e-1) )
            Y_1_mu = f_Y(N_Y_1)
            Y_2_mu = f_Y(N_Y_2)
            Y_3_mu = f_Y(N_Y_3)
            Y_4_mu = f_Y(N_Y_4)
            Y_5_mu = f_Y(N_Y_5)
            Y_1 = pyro.sample('Y_1', dist.Normal( Y_1_mu, 1e-1) )
            Y_2 = pyro.sample('Y_2', dist.Normal( Y_2_mu, 1e-1) )
            Y_3 = pyro.sample('Y_3', dist.Normal( Y_3_mu, 1e-1) )
            Y_4 = pyro.sample('Y_4', dist.Normal( Y_4_mu, 1e-1) )
            Y_5 = pyro.sample('Y_5', dist.Normal( Y_5_mu, 1e-1) )
            Y_mu = (Y_1_mu, Y_2_mu, Y_3_mu, Y_4_mu, Y_5_mu)
            X = pyro.sample('X', dist.Normal( f_X( Y_mu, Z, N_X ), 1e-1) )
            
            noise_samples = N_X, (N_Y_1, N_Y_2, N_Y_3, N_Y_4, N_Y_5), N_Z
            variable_samples = X, (Y_1, Y_2, Y_3, Y_4, Y_5), Z
            
            return variable_samples, noise_samples
        
        self.model = model
        
        self.init_noise = {
            'N_X'   : dist.Uniform(torch.zeros(vae.image_dim), torch.ones(vae.image_dim)),
            'N_Z'   : dist.Normal(torch.zeros(vae.z_dim), torch.ones(vae.z_dim)),
            'N_Y_1' : dist.Uniform(torch.zeros(label_dims[1]),torch.ones(self.label_dims[1])),
            'N_Y_2' : dist.Uniform(torch.zeros(label_dims[2]),torch.ones(self.label_dims[2])),
            'N_Y_3' : dist.Uniform(torch.zeros(label_dims[3]),torch.ones(self.label_dims[3])),
            'N_Y_4' : dist.Uniform(torch.zeros(label_dims[4]),torch.ones(self.label_dims[4])),
            'N_Y_5' : dist.Uniform(torch.zeros(label_dims[5]),torch.ones(self.label_dims[5]))            
        }
        
        
        
    def update_noise_svi(self, obs_data):
        # assume all noise variables are normal distributions
        # use svi to find out the mu, sigma of the distributions
        # for the condition outlined in obs_data
        def guide(noise):
            # create params with constraints
            mu = {'N_X': pyro.param('N_X_mu', 0.5*torch.ones(self.image_dim),
                                    constraint = constraints.interval(0., 1.)),
                  'N_Z': pyro.param('N_Z_mu', torch.zeros(self.z_dim),
                                    constraint = constraints.interval(-3., 3.)),
                  'N_Y_1': pyro.param('N_Y_1_mu', 0.5*torch.ones(self.label_dims[1]),
                                    constraint = constraints.interval(0., 1.)),
                  'N_Y_2': pyro.param('N_Y_2_mu', 0.5*torch.ones(self.label_dims[2]),
                                    constraint = constraints.interval(0., 1.)),
                  'N_Y_3': pyro.param('N_Y_3_mu', 0.5*torch.ones(self.label_dims[3]),
                                    constraint = constraints.interval(0., 1.)),
                  'N_Y_4': pyro.param('N_Y_4_mu', 0.5*torch.ones(self.label_dims[4]),
                                    constraint = constraints.interval(0., 1.)),
                  'N_Y_5': pyro.param('N_Y_5_mu', 0.5*torch.ones(self.label_dims[5]),
                                    constraint = constraints.interval(0., 1.))
                }
            sigma = {'N_X': pyro.param('N_X_sigma', 0.1*torch.ones(self.image_dim),
                                    constraint = constraints.interval(0.0001, 0.5)),
                      'N_Z': pyro.param('N_Z_sigma', torch.ones(self.z_dim),
                                        constraint = constraints.interval(0.0001, 3.)),
                      'N_Y_1': pyro.param('N_Y_1_sigma', 0.1*torch.ones(self.label_dims[1]),
                                        constraint = constraints.interval(0.0001, 0.5)),
                      'N_Y_2': pyro.param('N_Y_2_sigma', 0.1*torch.ones(self.label_dims[2]),
                                        constraint = constraints.interval(0.0001, 0.5)),
                      'N_Y_3': pyro.param('N_Y_3_sigma', 0.1*torch.ones(self.label_dims[3]),
                                        constraint = constraints.interval(0.0001, 0.5)),
                      'N_Y_4': pyro.param('N_Y_4_sigma', 0.1*torch.ones(self.label_dims[4]),
                                        constraint = constraints.interval(0.0001, 0.5)),
                      'N_Y_5': pyro.param('N_Y_5_sigma', 0.1*torch.ones(self.label_dims[5]),
                                    constraint = constraints.interval(0.0001, 0.5))
                }
                  
            for noise_term in noise.keys():
                pyro.sample(noise_term, dist.Normal(mu[noise_term], sigma[noise_term]))
        
        obs_model = pyro.condition(self.model, obs_data)
        pyro.clear_param_store()
        svi = SVI(
            model= obs_model,
            guide= guide,
            optim= Adam({"lr": 1e-3}),
            loss=Trace_ELBO()
        )
        
        num_steps = 1000
        samples = defaultdict(list)
        for t in range(num_steps):
            svi.step(self.init_noise)
            
        # now determine new noise variables
        for noise in initial_noise.keys():
            mu = '{}_mu'.format(noise)
            sigma = '{}_sigma'.format(noise)
            samples[mu].append(pyro.param(mu).item())
            samples[sigma].append(pyro.param(sigma).item())
        means = {k: torch.mean(torch.cat(v, 0),0) for k, v in samples.items()}
        
        updated_noise = {
            'N_X': dist.Normal(means['N_X_mu'], means['N_X_sigma']),
            'N_Z': dist.Normal(means['N_Z_mu'], means['N_Z_sigma']),
            'N_Y_1': dist.Normal(means['N_Y_1_mu'], means['N_Y_1_sigma']),
            'N_Y_2': dist.Normal(means['N_Y_2_mu'], means['N_Y_2_sigma']),
            'N_Y_3': dist.Normal(means['N_Y_3_mu'], means['N_Y_3_sigma']),
            'N_Y_4': dist.Normal(means['N_Y_4_mu'], means['N_Y_4_sigma']),
            'N_Y_5': dist.Normal(means['N_Y_5_mu'], means['N_Y_5_sigma']),
        }
        
        return updated_noise
        
    def __call__(self):
        return self.model(self.init_noise)

But when I run this,

x, y = get_specific_data(cuda=True)
mu, sigma = vae.encoder.forward(x,vae.remap_y(y))
scm = SCM(vae, mu.cpu(), sigma.cpu())

cond_data = {}
for i in range(1, 6):
    cond_data["Y_{}".format(i)] = torch.tensor(y[0,i].cpu()).to(torch.float32)

cond_noise = scm.update_noise_svi(cond_data)

I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-278-e22c3f786618> in <module>
      4 
      5 
----> 6 cond_noise = scm.update_noise_svi(cond_data)

<ipython-input-276-8176258ce037> in update_noise_svi(self, obs_data)
    128         samples = defaultdict(list)
    129         for t in range(num_steps):
--> 130             svi.step(self.init_noise)
    131 
    132         # now determine new noise variables

~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
     97         # get loss and compute gradients
     98         with poutine.trace(param_only=True) as param_capture:
---> 99             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    100 
    101         params = set(site["value"].unconstrained()

~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    123         loss = 0.0
    124         # grab a trace from the generator
--> 125         for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
    126             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
    127             loss += loss_particle / self.num_particles

~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, *args, **kwargs)
    166         else:
    167             for i in range(self.num_particles):
--> 168                 yield self._get_trace(model, guide, *args, **kwargs)

~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, *args, **kwargs)
     50         """
     51         model_trace, guide_trace = get_importance_trace(
---> 52             "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
     53         if is_validation_enabled():
     54             check_if_enumerated(guide_trace)

~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, *args, **kwargs)
     54         for site in model_trace.nodes.values():
     55             if site["type"] == "sample":
---> 56                 check_site_shape(site, max_plate_nesting)
     57         for site in guide_trace.nodes.values():
     58             if site["type"] == "sample":

~/anaconda2/envs/cs7180/lib/python3.6/site-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
    260                 '- enclose the batched tensor in a with plate(...): context',
    261                 '- .to_event(...) the distribution being sampled',
--> 262                 '- .permute() data dimensions']))
    263 
    264     # Check parallel dimensions on the left of max_plate_nesting.

ValueError: at site "N_X", invalid log_prob shape
  Expected [], actual [4096]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Please let me know what the issue is as I am new to Pyro and I am not sure what is causing this issue.
Also, this is with regards to the project I am working for my course on Causal Inference in Machine Learning. Here is the link to course: https://github.com/robertness/causalML
Thank you for the support!

what is the shape of your data and noise['N_X']? the error is saying that you are trying to score a sample with a different shape than the distribution. i recommend stepping through a debugger, which will tell you exactly what you are passing into each distribution, making sure that your guide dist shapes match your model dist shapes and your observations are of the correct shape using to_event as necessary.

Is this of any help to you?

trace = pyro.poutine.trace(SCM(vae, mu, sigma)).get_trace()

trace.compute_log_prob() # optional, but allows printing of log_prob shapes

print(trace.format_shapes())

Output:

Trace Shapes:         
 Param Sites:         
Sample Sites:         
     N_X dist   4096 |
        value   4096 |
     log_prob   4096 |
   N_Y_1 dist      3 |
        value      3 |
     log_prob      3 |
   N_Y_2 dist      6 |
        value      6 |
     log_prob      6 |
   N_Y_3 dist     40 |
        value     40 |
     log_prob     40 |
   N_Y_4 dist     32 |
        value     32 |
     log_prob     32 |
   N_Y_5 dist     32 |
        value     32 |
     log_prob     32 |
     N_Z dist     50 |
        value     50 |
     log_prob     50 |
       Z dist 1   50 |
        value 1   50 |
     log_prob 1   50 |
     Y_1 dist        |
        value        |
     log_prob        |
     Y_2 dist        |
        value        |
     log_prob        |
     Y_3 dist        |
        value        |
     log_prob        |
     Y_4 dist        |
        value        |
     log_prob        |
     Y_5 dist        |
        value        |
     log_prob        |
       X dist 1 4096 |
        value 1 4096 |
     log_prob 1 4096 |

@jpchen these were generated by printing .shape for each noise variable:

----- guide -------- 
N_X torch.Size([4096]) 
N_Z torch.Size([50]) 
N_Y_1 torch.Size([3]) 
N_Y_2 torch.Size([6]) 
N_Y_3 torch.Size([40]) 
N_Y_4 torch.Size([32]) 
N_Y_5 torch.Size([32]) 
----- model ----- 
N_X torch.Size([4096]) 
N_Z torch.Size([50]) 
N_Y_1 torch.Size([3]) 
N_Y_2 torch.Size([6]) 
N_Y_3 torch.Size([40]) 
N_Y_4 torch.Size([32]) 
N_Y_5 torch.Size([32])

We get the same results if we print the batch_size of each underlying distributions