Problems with masking: "Found vars in model but not guide"

I am masking a small number of observations for some participants (but not all) from my inference. As soon as I use a mask however, I get the warning “Found vars in model but not guide:”, listing all sample sites. Weirdly, I get this error even if I always use an obs_mask with only 1s.

    def model(self):
        # define hyper priors over model parameters
        # prior over sigma of a Gaussian is a Gamma distribution
        a = pyro.param('a', torch.ones(self.n_parameters), constraint=dist.constraints.positive)
        lam = pyro.param('lam', torch.ones(self.n_parameters), constraint=dist.constraints.positive)
        tau = pyro.sample('tau', dist.Gamma(a, a/lam).to_event(1)) # Why a/lam?
        
        sig = 1/torch.sqrt(tau) # Gaus sigma

        # each model parameter has a hyperprior defining group level mean
        # in the form of a Normal distribution
        m = pyro.param('m', torch.zeros(self.n_parameters))
        s = pyro.param('s', torch.ones(self.n_parameters), constraint=dist.constraints.positive)
        mu = pyro.sample('mu', dist.Normal(m, s*sig).to_event(1)) # Gauss mu, wieso s*sig?

        with pyro.plate('subject', self.n_subjects) as ind:
    
            # draw parameters from Normal and transform (for numeric reasons)
            base_dist = dist.Normal(0., 1.).expand_by([self.n_parameters]).to_event(1)
            transform = dist.transforms.AffineTransform(mu, sig)
            locs = pyro.sample('locs', dist.TransformedDistribution(base_dist, [transform]))
    
            if locs.ndim == 2:
                locs = locs[None, :]
                
            self.agent.reset(locs)
            
            n_particles = locs.shape[0]
            t = -1
            for tau in pyro.markov(range(self.trials)):
    
                trial = self.data["Trialsequence"][tau]
                blocktype = self.data["Blocktype"][tau]
                
                if all([self.data["Blockidx"][tau][i] <= 5 for i in range(self.n_subjects)]):
                    day = 1
                    
                elif all([self.data["Blockidx"][tau][i] > 5 for i in range(self.n_subjects)]):
                    day = 2
                    
                else:
                    raise Exception("error.")
                
                if all([trial[i] == -1 for i in range(self.n_subjects)]):
                    "Beginning of new block of experiment"
                    self.agent.update(torch.tensor([-1]), torch.tensor([-1]), torch.tensor([-1]), day=day, trialstimulus=trial)
                    
                else:
                    current_choice = self.data["Choices"][tau]
                    outcome = self.data["Outcomes"][tau]
                
                if all([trial[i] > 10 for i in range(self.n_subjects)]):
                    t+=1
                    option1, option2 = self.agent.find_resp_options(trial)
                    
                    probs = self.agent.compute_probs(trial, day)
                    
                    choices = torch.tensor([0 if current_choice[idx] == option1[idx] else 1 for idx in range(len(current_choice))])
                    obs_mask = torch.tensor([0 if cc == -10 else 1 for cc in current_choice ]).type(torch.bool)
                    
                if all([trial[i] != -1 for i in range(self.n_subjects)]):
                    self.agent.update(current_choice, outcome, blocktype, day=day, trialstimulus=trial)

                "Sample if no error was performed"
                if all([trial[i] > 10 for i in range(self.n_subjects)]):
                    pyro.sample('res_{}'.format(t), dist.Categorical(probs=probs), \
                                obs = choices.broadcast_to(n_particles, self.n_subjects), \
                                obs_mask = obs_mask.broadcast_to(n_particles, self.n_subjects))

    def guide(self):
        trns = torch.distributions.biject_to(dist.constraints.positive)
    
        # define mean vector and covariance matrix of multivariate normal
        m_hyp = pyro.param('m_hyp', torch.zeros(2*self.n_parameters))
        st_hyp = pyro.param('scale_tril_hyp',
                       torch.eye(2*self.n_parameters),
                       constraint=dist.constraints.lower_cholesky)
        
        # set hyperprior to be multivariate normal
        hyp = pyro.sample('hyp',
                     dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})
    
        unc_mu = hyp[..., :self.n_parameters]
        unc_tau = hyp[..., self.n_parameters:]

        c_tau = trns(unc_tau)

        ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = dist.util.sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)
    
        # some numerics tricks
        mu = pyro.sample("mu", dist.Delta(unc_mu, event_dim=1))
        tau = pyro.sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))
    
        m_locs = pyro.param('m_locs', torch.zeros(self.n_subjects, self.n_parameters))
        st_locs = pyro.param('scale_tril_locs',
                        torch.eye(self.n_parameters).repeat(self.n_subjects, 1, 1),
                        constraint=dist.constraints.lower_cholesky)
        
        with pyro.plate('subject', self.n_subjects):
            locs = pyro.sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))

        return {'tau': tau, 'mu': mu, 'locs': locs}

I am not sure what to do about my guide. What can I do?

it’d probably be best if you provide a small minimal working example. wer hat die zeit das ganze genau durchzulesen?

Can we stick to English please?
Providing a minimal working example will take a while since the project is pretty big and convoluted. In the meantime, is there anything you could point me to? I think the problem is still quite clear, even if I use obs_mask = torch.ones_like(choices.broadcast_to(n_particles, self.n_subjects), dtype = torch.bool) as a mask, I get the warning

UserWarning: Found vars in model but not guide: {'res_800_unobserved', 'res_930_unobserved', 'res_199_unobserved', ...

and so listing each sampling site. The pyro tutorials don’t really cover masking or how it relates to the guide.

what is your precise intention? as it states in the docs

  • obs_mask (bool or Tensor) – Optional boolean tensor mask of shape broadcastable with fn.batch_shape. If provided, events with mask=True will be conditioned on obs and remaining events will be imputed by sampling. This introduces a latent sample site named name + "_unobserved" which should be used by guides.

if you want to use obs_mask you need to introduce the corresponding sample sites into your guide.

Yes thank you, that was the problem. I did not realize that the .mask() method of a distribution works differently from the obs_mask argument of pyro.sample(). obs_mask performs Bayesian imputation, which is why unobserved sites have to be modeled by the guide, while the .mask() method of a distribution allows to simply ignore an observation, without imputation.