Hi,
my likelihood is a mixture model of three reinforcement learning agents with planning depths d = (1, 2, 3), combined via a categorical distribution. The agents play a planning game consisting of mini-blocks (indexed by b), each with 3 sequential actions (indexed by t). I am inferring the probability over planning depths d from the participants' actual choices, using SVI with TraceEnum_ELBO. Here is my model:
import torch
from torch import zeros, ones, tensor
from torch.distributions import biject_to, constraints

import pyro
from pyro import sample, param, plate, markov
import pyro.distributions as dist
from pyro.distributions.util import sum_rightmost
from pyro.ops.indexing import Vindex


class Inferrer(object):

    def __init__(self, agent, stimuli, responses, mask):
        self.agent = agent
        self.nsub, self.nblk = responses.shape[:2]
        self.responses = responses
        self.mask = mask
        self.N = mask.sum(dim=0)

        # transition probabilities over planning depths, conditional on the
        # condition (number of steps in the mini-block)
        self.depth_transition = zeros(2, 3, 2, 3)
        self.depth_transition[0, :, 0] = tensor([1., 0., 0.])
        self.depth_transition[0, :, 1] = tensor([.5, .5, 0.])
        self.depth_transition[1] = tensor([1., 0., 0.])

        self.states = stimuli['states']
        self.configs = stimuli['configs']
        self.conditions = stimuli['conditions']
    def model_static(self):
        """Assume a static prior over planning depth per condition."""
        agent = self.agent
        np = agent.np     # number of parameters
        nblk = self.nblk  # number of mini-blocks, i.e. trials
        nsub = self.nsub  # number of subjects

        # define hyper priors over model parameters:
        # prior uncertainty over model parameters and subjects
        a = param('a', 2*ones(np), constraint=constraints.positive)
        r = param('r', 2*ones(np), constraint=constraints.positive)
        tau = sample('tau', dist.Gamma(a, r/a).to_event(1))
        sig = 1./torch.sqrt(tau)

        # prior mean over model parameters
        m = param('m', zeros(np))
        s = param('s', ones(np), constraint=constraints.positive)
        mu = sample("mu", dist.Normal(m, s*sig).to_event(1))

        # concentration parameters of the Dirichlet priors over planning depth
        alphas_c1 = param("alphas_c1", ones(2), constraint=constraints.positive)
        alphas_c2 = param("alphas_c2", ones(3), constraint=constraints.positive)

        with plate('subjects', nsub):
            # locs.shape : [num_particles, num_subjects, 3]
            locs = sample("locs", dist.Normal(mu, sig).to_event(1))

            # define priors over planning depth:
            # probs_c1: remaining steps = 2 (a1 in a 3-step mini-block, a0 in a 2-step mini-block)
            # probs_c2: remaining steps = 3 (a0 in a 3-step mini-block)
            probs_c1 = sample("probs_c1", dist.Dirichlet(alphas_c1))
            probs_c2 = sample("probs_c2", dist.Dirichlet(alphas_c2))

            agent.set_parameters(locs)
            shape = agent.batch_shape

            # embed the 2-simplex probs_c1 into 3 dimensions and stack the
            # two conditions along a new axis
            tmp = zeros(shape + (3,))
            tmp[..., :2] = probs_c1
            priors = torch.stack([tmp, probs_c2], -2)
            for b in markov(range(nblk)):
                conditions = self.conditions[..., b]
                states = self.states[:, b]
                responses = self.responses[:, b]

                # maximal number of steps in the current mini-block (2 or 3)
                max_trials = conditions[-1]
                tm = self.depth_transition[:, :, max_trials - 2]

                for t in markov(range(3)):
                    if t == 0:
                        res = None
                        probs = priors[..., range(nsub), max_trials - 2, :]
                    else:
                        res = responses[:, t-1]
                        probs = tm[t-1, -1]

                    agent.update_beliefs(b, t, states[:, t], conditions, res)
                    agent.plan_actions(b, t)

                    # restrict to valid responses
                    valid = self.mask[:, b, t]
                    N = self.N[b, t]
                    res = responses[valid, t]
                    logits = agent.logits[-1][..., valid, :]
                    probs = probs[..., valid, :]

                    if t == 2:
                        agent.update_beliefs(b, t + 1, states[:, t + 1], conditions, responses[:, t])

                    if N > 0:
                        with plate('responses_{}_{}'.format(b, t), N):
                            # enumerate planning depths d in parallel
                            d = sample('d_{}_{}'.format(b, t),
                                       dist.Categorical(probs),
                                       infer={"enumerate": "parallel"})
                            sample('obs_{}_{}'.format(b, t),
                                   dist.Bernoulli(logits=Vindex(logits)[..., d]),
                                   obs=res)
    def guide_static(self):
        npar = self.agent.np  # number of parameters
        nsub = self.nsub      # number of subjects

        # joint auxiliary posterior over (mu, tau) in unconstrained space
        m_hyp = param('m_hyp', zeros(2*npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2*npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp', dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[..., :npar]
        unc_tau = hyp[..., npar:]

        # map tau back to the positive reals and track the log-det Jacobian
        trns_tau = biject_to(constraints.positive)
        c_tau = trns_tau(unc_tau)
        ld_tau = trns_tau.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

        mu = sample("mu", dist.Delta(unc_mu, event_dim=1))
        tau = sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('s_locs', torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        alpha1 = param('alpha1', ones(nsub, 2), constraint=constraints.positive)
        alpha2 = param('alpha2', ones(nsub, 3), constraint=constraints.positive)

        with plate('subjects', nsub):
            locs = sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
            probsc1 = sample("probs_c1", dist.Dirichlet(alpha1))
            probsc2 = sample("probs_c2", dist.Dirichlet(alpha2))

        return {'mu': mu, 'tau': tau, 'locs': locs, 'pc1': probsc1, 'pc2': probsc2}
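I fit this with SVI roughly as follows (a minimal sketch; the learning rate, number of steps, and max_plate_nesting are placeholders, not the exact settings):

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

infer = Inferrer(agent, stimuli, responses, mask)
elbo = TraceEnum_ELBO(max_plate_nesting=2)  # 'subjects' and 'responses_*' plates
svi = SVI(infer.model_static, infer.guide_static, Adam({'lr': 0.01}), loss=elbo)
for step in range(1000):
    svi.step()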
Inference with this setup works well. I can retrieve the posterior distribution over planning depths for each trial from the marginals:

post_depth = elbo.compute_marginals(self.model, self.guide)
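compute_marginals returns a dict keyed by the enumerated sites, so the per-trial distributions can be read out directly (a minimal sketch):

for name, marginal in post_depth.items():
    # sites are named 'd_{b}_{t}', one Categorical marginal per trial and action
    print(name, marginal.probs)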
Now I would like to sample posterior predictions from the likelihood to investigate model fit. However, with my attempt I don't get planning-depth distributions per trial, and I don't know why. I am doing:

guide_sample = self.guide()

# condition the likelihood (the model) on one sample from the approximate posterior
conditioned_model = pyro.condition(self.model,
                                   data={'locs': guide_sample['locs'].detach(),
                                         'probs_c1': guide_sample['pc1'].detach(),
                                         'probs_c2': guide_sample['pc2'].detach()})

# run the conditioned model once
trace = pyro.poutine.trace(conditioned_model).get_trace()
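To see what the discrete sites received, I inspect the trace like this (a minimal sketch):

# each depth site 'd_{b}_{t}' stores the distribution it was sampled from
for name, node in trace.nodes.items():
    if name.startswith('d_'):
        print(name, node['fn'].probs)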
In the trace, the parameters of the subject plate are correctly sampled. However, the distribution over the discrete planning depths is the same for all trials: it simply equals the corresponding hyperparameters probs_c1 and probs_c2 (a single distribution per task condition). The trial-wise fitted distributions get lost somewhere. What am I missing? Why is the probability over planning depth d not sampled from the posterior of each individual trial? Do I have to use compute_marginals again when sampling?