I am masking a small number of observations for some participants (but not all) from my inference. As soon as I use a mask however, I get the warning “Found vars in model but not guide:”, listing all sample sites. Weirdly, I get this error even if I always use an obs_mask with only 1s.
def model(self):
# define hyper priors over model parameters
# prior over sigma of a Gaussian is a Gamma distribution
a = pyro.param('a', torch.ones(self.n_parameters), constraint=dist.constraints.positive)
lam = pyro.param('lam', torch.ones(self.n_parameters), constraint=dist.constraints.positive)
tau = pyro.sample('tau', dist.Gamma(a, a/lam).to_event(1)) # Why a/lam?
sig = 1/torch.sqrt(tau) # Gaus sigma
# each model parameter has a hyperprior defining group level mean
# in the form of a Normal distribution
m = pyro.param('m', torch.zeros(self.n_parameters))
s = pyro.param('s', torch.ones(self.n_parameters), constraint=dist.constraints.positive)
mu = pyro.sample('mu', dist.Normal(m, s*sig).to_event(1)) # Gauss mu, wieso s*sig?
with pyro.plate('subject', self.n_subjects) as ind:
# draw parameters from Normal and transform (for numeric reasons)
base_dist = dist.Normal(0., 1.).expand_by([self.n_parameters]).to_event(1)
transform = dist.transforms.AffineTransform(mu, sig)
locs = pyro.sample('locs', dist.TransformedDistribution(base_dist, [transform]))
if locs.ndim == 2:
locs = locs[None, :]
self.agent.reset(locs)
n_particles = locs.shape[0]
t = -1
for tau in pyro.markov(range(self.trials)):
trial = self.data["Trialsequence"][tau]
blocktype = self.data["Blocktype"][tau]
if all([self.data["Blockidx"][tau][i] <= 5 for i in range(self.n_subjects)]):
day = 1
elif all([self.data["Blockidx"][tau][i] > 5 for i in range(self.n_subjects)]):
day = 2
else:
raise Exception("error.")
if all([trial[i] == -1 for i in range(self.n_subjects)]):
"Beginning of new block of experiment"
self.agent.update(torch.tensor([-1]), torch.tensor([-1]), torch.tensor([-1]), day=day, trialstimulus=trial)
else:
current_choice = self.data["Choices"][tau]
outcome = self.data["Outcomes"][tau]
if all([trial[i] > 10 for i in range(self.n_subjects)]):
t+=1
option1, option2 = self.agent.find_resp_options(trial)
probs = self.agent.compute_probs(trial, day)
choices = torch.tensor([0 if current_choice[idx] == option1[idx] else 1 for idx in range(len(current_choice))])
obs_mask = torch.tensor([0 if cc == -10 else 1 for cc in current_choice ]).type(torch.bool)
if all([trial[i] != -1 for i in range(self.n_subjects)]):
self.agent.update(current_choice, outcome, blocktype, day=day, trialstimulus=trial)
"Sample if no error was performed"
if all([trial[i] > 10 for i in range(self.n_subjects)]):
pyro.sample('res_{}'.format(t), dist.Categorical(probs=probs), \
obs = choices.broadcast_to(n_particles, self.n_subjects), \
obs_mask = obs_mask.broadcast_to(n_particles, self.n_subjects))
def guide(self):
trns = torch.distributions.biject_to(dist.constraints.positive)
# define mean vector and covariance matrix of multivariate normal
m_hyp = pyro.param('m_hyp', torch.zeros(2*self.n_parameters))
st_hyp = pyro.param('scale_tril_hyp',
torch.eye(2*self.n_parameters),
constraint=dist.constraints.lower_cholesky)
# set hyperprior to be multivariate normal
hyp = pyro.sample('hyp',
dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
infer={'is_auxiliary': True})
unc_mu = hyp[..., :self.n_parameters]
unc_tau = hyp[..., self.n_parameters:]
c_tau = trns(unc_tau)
ld_tau = trns.inv.log_abs_det_jacobian(c_tau, unc_tau)
ld_tau = dist.util.sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)
# some numerics tricks
mu = pyro.sample("mu", dist.Delta(unc_mu, event_dim=1))
tau = pyro.sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))
m_locs = pyro.param('m_locs', torch.zeros(self.n_subjects, self.n_parameters))
st_locs = pyro.param('scale_tril_locs',
torch.eye(self.n_parameters).repeat(self.n_subjects, 1, 1),
constraint=dist.constraints.lower_cholesky)
with pyro.plate('subject', self.n_subjects):
locs = pyro.sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
return {'tau': tau, 'mu': mu, 'locs': locs}
I am not sure what to do about my guide. What can I do?