Likelihood sampling with discrete latent variable

Hi,
my likelihood is a mixture of three reinforcement learning agents with planning depths d = (1, 2, 3), combined via a categorical distribution. The agents play a planning game consisting of mini-blocks/trials (b), each with 3 sequential actions (t). I am inferring the probability over planning depths d from the participants' true choices with SVI and TraceEnum_ELBO (a rough sketch of the fitting step follows the code below). Here is my model:

# imports assumed for this snippet (not shown in the original post)
import torch
import pyro
import pyro.distributions as dist
from torch import zeros, ones, tensor
from torch.distributions import biject_to
from pyro import sample, param, plate, markov
from pyro.distributions import constraints
from pyro.distributions.util import sum_rightmost
from pyro.ops.indexing import Vindex


class Inferrer(object):

    def __init__(self, agent, stimuli, responses, mask):
        self.agent = agent
        self.nsub, self.nblk = responses.shape[:2]

        self.responses = responses
        self.mask = mask
        self.N = mask.sum(dim=0)

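        # (as used below) distribution over the remaining planning depth after each action,
        # indexed as [action step t-1, previous depth, condition (max_trials - 2), next depth]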
        self.depth_transition = zeros(2, 3, 2, 3)
        self.depth_transition[0, :, 0] = tensor([1., 0., 0.])
        self.depth_transition[0, :, 1] = tensor([.5, .5, 0.])
        self.depth_transition[1] = tensor([1., 0., 0.])

        self.states = stimuli['states']
        self.configs = stimuli['configs']
        self.conditions = stimuli['conditions']

    def model_static(self):
        """Assume static prior over planning depth per condition.
        """
        agent = self.agent
        np = agent.np  # number of parameters

        nblk = self.nblk  # number of mini-blocks, i.e. trials
        nsub = self.nsub  # number of subjects

        # define hyper priors over model parameters.

        # define prior uncertainty over model parameters and subjects
        a = param('a', 2*ones(np), constraint=constraints.positive)
        r = param('r', 2*ones(np), constraint=constraints.positive)
        tau = sample('tau', dist.Gamma(a, r/a).to_event(1))

        sig = 1./torch.sqrt(tau)
        # define prior mean over model parameters
        m = param('m', zeros(np))
        s = param('s', ones(np), constraint=constraints.positive)
        mu = sample("mu", dist.Normal(m, s*sig).to_event(1))

        alphas_c1 = param("alphas_c1", ones(2), constraint=constraints.positive)
        alphas_c2 = param("alphas_c2", ones(3), constraint=constraints.positive)

        with plate('subjects', nsub):
            # locs.shape : [num_particles, num_subjects, 3]
            locs = sample("locs", dist.Normal(mu, sig).to_event(1))
            # define priors over planning depth
            probs_c1 = sample("probs_c1", dist.Dirichlet(alphas_c1)) # prior over planning depths if remaining steps = 2 (i.e. for a1 in 3-step miniblock and a0 in 2-step miniblock)
            probs_c2 = sample("probs_c2", dist.Dirichlet(alphas_c2)) # prior over planning depths if remaining steps = 3 (i.e. for a0 in 3-step miniblock)

        agent.set_parameters(locs)

        shape = agent.batch_shape

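        # stack the per-condition depth priors along dim -2 (used at t == 0):
        # index 0 -> 2-step mini-block (probs_c1 padded with zero for depth 3),
        # index 1 -> 3-step mini-block (probs_c2)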
        tmp = zeros(shape + (3,))
        tmp[..., :2] = probs_c1

        priors = torch.stack([tmp, probs_c2], -2)

        for b in markov(range(nblk)):
            conditions = self.conditions[..., b]
            states = self.states[:, b]
            responses = self.responses[:, b]
            max_trials = conditions[-1]

            tm = self.depth_transition[:, :, max_trials - 2]
            for t in markov(range(3)):

                if t == 0:
                    res = None
                    probs = priors[..., range(nsub), max_trials - 2, :]
                else:
                    res = responses[:, t-1]
                    probs = tm[t-1, -1]

                agent.update_beliefs(b, t, states[:, t], conditions, res)
                agent.plan_actions(b, t)

                valid = self.mask[:, b, t]
                N = self.N[b, t]
                res = responses[valid, t]

                logits = agent.logits[-1][..., valid, :]
                probs = probs[..., valid, :]

                if t == 2:
                    agent.update_beliefs(b, t + 1, states[:, t + 1], conditions, responses[:, t])

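                # enumerate planning depth d in parallel and score the observed
                # responses under the depth-specific logits (mixture over depths)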
                if N > 0:
                    with plate('responses_{}_{}'.format(b, t), N):
                        d = sample('d_{}_{}'.format(b, t),
                                   dist.Categorical(probs),
                                   infer={"enumerate": "parallel"})

                        sample('obs_{}_{}'.format(b, t),
                               dist.Bernoulli(logits=Vindex(logits)[..., d]),
                               obs=res)

    def guide_static(self):

        npar = self.agent.np  # number of parameters
        nsub = self.nsub  # number of subjects

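        # joint (correlated) Gaussian guide over the group-level parameters: an
        # auxiliary sample over (mu, unconstrained tau) is mapped to the positive
        # domain for tau and injected through Delta sites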
        m_hyp = param('m_hyp', zeros(2*npar))
        st_hyp = param('scale_tril_hyp',
                       torch.eye(2*npar),
                       constraint=constraints.lower_cholesky)
        hyp = sample('hyp', dist.MultivariateNormal(m_hyp, scale_tril=st_hyp),
                     infer={'is_auxiliary': True})

        unc_mu = hyp[..., :npar]
        unc_tau = hyp[..., npar:]

        trns_tau = biject_to(constraints.positive)

        c_tau = trns_tau(unc_tau)

        ld_tau = trns_tau.inv.log_abs_det_jacobian(c_tau, unc_tau)
        ld_tau = sum_rightmost(ld_tau, ld_tau.dim() - c_tau.dim() + 1)

        mu = sample("mu", dist.Delta(unc_mu, event_dim=1))
        tau = sample("tau", dist.Delta(c_tau, log_density=ld_tau, event_dim=1))

        m_locs = param('m_locs', zeros(nsub, npar))
        st_locs = param('s_locs', torch.eye(npar).repeat(nsub, 1, 1),
                        constraint=constraints.lower_cholesky)

        alpha1 = param('alpha1', ones(nsub, 2), constraint=constraints.positive)
        alpha2 = param('alpha2', ones(nsub, 3), constraint=constraints.positive)

        with plate('subjects', nsub):
            locs = sample("locs", dist.MultivariateNormal(m_locs, scale_tril=st_locs))
            probsc1 = sample("probs_c1", dist.Dirichlet(alpha1))
            probsc2 = sample("probs_c2", dist.Dirichlet(alpha2))

        return {'mu': mu, 'tau': tau, 'locs': locs, 'pc1': probsc1, 'pc2': probsc2}
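
For reference, the fitting step looks roughly like this (the optimizer, learning rate, max_plate_nesting, and the aliases self.model / self.guide are illustrative placeholders, not necessarily my exact settings):

import pyro
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

inferrer = Inferrer(agent, stimuli, responses, mask)
inferrer.model = inferrer.model_static
inferrer.guide = inferrer.guide_static

elbo = TraceEnum_ELBO(max_plate_nesting=1)  # adjust to the plate structure
svi = SVI(model=inferrer.model, guide=inferrer.guide,
          optim=Adam({'lr': 0.01}), loss=elbo)

for step in range(num_steps):  # num_steps is a placeholder
    loss = svi.step()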

Inference with this setup works well. I can retrieve posterior distributions over planning depth for each trial by using compute_marginals:

post_depth = elbo.compute_marginals(self.model, self.guide)

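As far as I understand, compute_marginals returns a dict that maps each enumerated site to its marginal distribution, so the per-trial depth probabilities can be read off directly (site names follow the d_{b}_{t} convention from the model; the shape comment is approximate):

# marginal posterior over planning depth at mini-block b=0, action t=0
probs_b0_t0 = post_depth['d_0_0'].probs  # roughly (n_valid_subjects, 3)
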
Now I would like to sample posterior predictions from the likelihood to investigate model fit. However, with my attempt I don’t get per-trial planning depth distributions, and I don’t know why. I am doing:

guide_sample = self.guide()
# condition the likelihood (model) on one sample from the approx. posterior
conditioned_model = pyro.condition(
    self.model,
    data={'locs': guide_sample['locs'].detach(),
          'probs_c1': guide_sample['pc1'].detach(),
          'probs_c2': guide_sample['pc2'].detach()})
# run conditioned model once
trace = pyro.poutine.trace(conditioned_model).get_trace()

In the trace, the parameters of the subject plate are correctly sampled. However, the distribution over discrete planning depths is the same for all trials; it is simply identical to the corresponding hyperparameters probs_c1 and probs_c2 (a single distribution per task condition). The trial-wise fitted distributions get lost somehow. What am I missing? Why is the probability over planning depth d not sampled from the posterior of each individual trial? Do I have to use compute_marginals again during sampling?

I have now implemented a workaround: I use guide() to sample the locs parameters from the subject plate and elbo.compute_marginals() to obtain the distributions over depth per trial. However, I did not find any way to condition my model on these samples and run it with get_trace(). So I wrote a function sample_posterior_predictions() in which I do the sampling described above and then basically copied my whole model into it, so that it runs “manually” with the sampled parameters (starting with agent.set_parameters()); see the sketch below.
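
In outline, that function looks like this (heavily abbreviated; self.elbo stands for the TraceEnum_ELBO instance used during fitting, and the omitted part is a copy of the block/trial loop from model_static):

def sample_posterior_predictions(self):
    # draw subject-level parameters from the approximate posterior
    guide_sample = self.guide()
    self.agent.set_parameters(guide_sample['locs'].detach())

    # per-trial marginal posteriors over planning depth
    post_depth = self.elbo.compute_marginals(self.model, self.guide)

    # ... then re-run the block/trial loop from model_static "manually",
    # drawing d from post_depth['d_{}_{}'.format(b, t)] and simulating
    # responses from Bernoulli(logits=...) instead of conditioning on data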