Collapsed Variational Inference

Like collapsed Gibbs sampling, collapsed variational inference is useful when a parameter can be integrated out (marginalized) analytically instead of being approximated. It is especially useful with non-parametric priors such as the normalized gamma process or NGIG. As a simple example, in the common Dirichlet process mixture model written with the stick-breaking construction, we collapse (marginalize out) the stick-breaking weights. This, however, changes the update equations for the variational distributions of the global and local latent variables: they now depend on the assignments of the other observations (as in collapsed Gibbs sampling, where the current assignment is conditioned on all the other assignments).
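To see what collapsing does to the assignments, here is a small sketch. It assumes a finite symmetric Dirichlet(alpha/K) prior on the mixing weights as a truncated stand-in for the stick-breaking weights (an assumption, not the stick-breaking construction itself). Integrating the weights out gives a Dirichlet-multinomial joint p(z), and the conditional of one assignment given all the others reduces to the Pólya urn rule (n_k^{-n} + alpha/K) / (N - 1 + alpha). The function names are hypothetical.

```python
import math
from collections import Counter


def log_collapsed_joint(z, K, alpha):
    """log p(z_1..z_N) with Dirichlet(alpha/K) mixing weights integrated out."""
    a = alpha / K
    counts = Counter(z)
    out = math.lgamma(alpha) - math.lgamma(len(z) + alpha)
    for k in range(K):
        out += math.lgamma(counts.get(k, 0) + a) - math.lgamma(a)
    return out


def polya_urn_conditional(z, n, k, K, alpha):
    """p(z_n = k | z_{-n}); the counts exclude observation n itself."""
    others = z[:n] + z[n + 1:]
    n_k = sum(1 for zi in others if zi == k)
    return (n_k + alpha / K) / (len(others) + alpha)


def conditional_from_joint(z, n, k, K, alpha):
    """The same conditional, by brute-force normalisation of the collapsed joint."""
    def joint_with(j):
        z2 = list(z)
        z2[n] = j
        return math.exp(log_collapsed_joint(z2, K, alpha))
    return joint_with(k) / sum(joint_with(j) for j in range(K))


z = [0, 0, 1, 2, 0, 1]
K, alpha = 3, 1.5
for k in range(K):
    print(k, polya_urn_conditional(z, 4, k, K, alpha),
          conditional_from_joint(z, 4, k, K, alpha))
```

Both columns agree: once the weights are gone, the only thing z_n "sees" is the counts of the other assignments, never its own.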


A simple model becomes something like this:
$x_n \sim F(x \mid \theta_{z_n})$
$\theta_k \sim \mathrm{Dirichlet}(\alpha)$
$z_n \sim \mathrm{Categorical}(1/K, \ldots, 1/K)$
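The generative process above can be simulated directly. This sketch assumes that F is a multinomial over V symbols with a fixed number of draws per observation and that the scalar alpha is a symmetric Dirichlet concentration; the function name `sample_finite_mixture` and all sizes are illustrative.

```python
import torch
import torch.distributions as tdist


def sample_finite_mixture(N, K, V, alpha, draws_per_obs, seed=0):
    torch.manual_seed(seed)
    # theta_k ~ Dirichlet(alpha): one categorical parameter vector per cluster
    theta = tdist.Dirichlet(torch.full((V,), alpha)).sample((K,))   # (K, V)
    # z_n ~ Categorical(1/K): uniform cluster assignment per observation
    z = tdist.Categorical(torch.full((K,), 1.0 / K)).sample((N,))  # (N,)
    # x_n ~ Multinomial(draws_per_obs, theta_{z_n})
    x = tdist.Multinomial(draws_per_obs, probs=theta[z]).sample()  # (N, V)
    return theta, z, x


theta, z, x = sample_finite_mixture(N=100, K=5, V=8, alpha=0.5, draws_per_obs=20)
print(theta.shape, z.shape, x.shape)
```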

Following the Gaussian mixture model example in the Pyro documentation, our model is a multinomial Dirichlet process mixture model:

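#I imagine this part remains the same
import torch
import pyro
import pyro.distributions as dist

K, alpha = 20, 1.0  # truncation level and concentration (illustrative values)

def model(data):  # data: (N, V) tensor of counts, one row per observation
    N, V = data.shape
    with pyro.plate("clusters", K):
        theta = pyro.sample("theta", dist.Dirichlet(torch.ones(V)))  # global RVs: cluster parameters
    clus_num = torch.zeros(K)  # number of obs in each cluster so far
    for n in range(N):
        # Chinese restaurant process update, truncated to K clusters (Polya urn with the weights collapsed out)
        z = pyro.sample(f"z_{n}", dist.Categorical(clus_num + alpha / K))
        # Local variables.
        pyro.sample(f"obs_{n}", dist.Multinomial(int(data[n].sum()), probs=theta[z]), obs=data[n])
        clus_num[z] += 1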

def guide(data):
    # q(theta_k) = Dirichlet(lambda_k): one Dirichlet variational factor per cluster;
    # the entropy of this Dirichlet is one of the terms of the ELBO

#For the assignments this is more difficult:

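#two parts: the first is the expected log-likelihood under q(theta), the second is the expected log conditional prior of z_n given the other assignments
$\log q(z_n = k) = \mathbb{E}_{q(\theta)}[\log p(x_n \mid \theta_k)] + \mathbb{E}_{q(z_{\neg n})}[\log p(z_n = k \mid z_{\neg n})] + \text{const}$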
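The first term has a closed form when F is a multinomial and q(theta_k) = Dirichlet(lambda_k): E[log theta_kv] = digamma(lambda_kv) - digamma(sum_v lambda_kv), so E[log p(x_n | theta_k)] = sum_v x_nv E[log theta_kv], up to the multinomial coefficient, which does not depend on k and cancels when q(z_n) is normalised. A sketch assuming those choices; `lam` and the function name are hypothetical.

```python
import torch


def expected_log_lik(data, lam):
    """E_{q(theta)}[log p(x_n | theta_k)] for all n, k, dropping the
    k-independent multinomial coefficient.

    data: (N, V) counts; lam: (K, V) Dirichlet parameters of q(theta_k).
    """
    elog_theta = torch.digamma(lam) - torch.digamma(lam.sum(-1, keepdim=True))  # (K, V)
    return data @ elog_theta.T  # (N, K)


data = torch.tensor([[3., 0., 1.], [0., 2., 2.]])
lam = torch.tensor([[4., 1., 1.], [1., 2., 3.]])
print(expected_log_lik(data, lam))
```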

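#second part: we can approximate it by moving the expectation inside the log

$\mathbb{E}_{q(z_{\neg n})}[\log p(z_n = k \mid z_{\neg n})] \approx \log \mathbb{E}[n_k^{\neg n}] + \text{const}$  #where n_k^{not n} is the current number of observations in cluster k excluding the current observation, i.e. clus_num[k] - 1 if observation n is in cluster k

#the expected count itself follows exactly from the variational distributions of the other assignments:
$\mathbb{E}[n_k^{\neg n}] = \sum_{i \neq n} q(z_i = k)$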
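Putting the two parts together in plain PyTorch (no `pyro.sample` involved), here is a coordinate-ascent sketch of the zeroth-order approximation, in which E[log(n_k^{-n} + alpha/K)] is replaced by log(E[n_k^{-n}] + alpha/K), the variant usually called CVB0. It assumes the finite Dirichlet(alpha/K) prior instead of a CRP (the alpha/K term also keeps empty clusters away from log 0), a multinomial F, and a mean-field Dirichlet q(theta) that is not collapsed. "Excluding the current one" amounts to subtracting observation n's own responsibilities from the column sums of `phi` before updating it. All names and the toy data are hypothetical.

```python
import torch


def update_assignments(phi, elog_lik, alpha):
    """One sequential sweep of q(z_n) updates with the weights collapsed out.

    phi: (N, K) current q(z_n = k); elog_lik: (N, K) E_q[log p(x_n | theta_k)].
    """
    N, K = phi.shape
    phi = phi.clone()
    expected_counts = phi.sum(0)  # E[n_k] over all observations
    for n in range(N):
        # Exclude the current observation: E[n_k^{-n}] = sum_{i != n} q(z_i = k)
        counts_not_n = expected_counts - phi[n]
        logits = elog_lik[n] + torch.log(counts_not_n + alpha / K)
        new_phi_n = torch.softmax(logits, dim=-1)
        # Keep the running totals consistent before moving to the next observation
        expected_counts = counts_not_n + new_phi_n
        phi[n] = new_phi_n
    return phi


def update_theta(data, phi, beta=1.0):
    """q(theta_k) = Dirichlet(beta + sum_n phi_nk x_n); beta is an assumed symmetric prior."""
    return beta + phi.T @ data  # (K, V)


def expected_log_lik(data, lam):
    elog_theta = torch.digamma(lam) - torch.digamma(lam.sum(-1, keepdim=True))
    return data @ elog_theta.T  # (N, K)


torch.manual_seed(0)
# Toy data: two groups of count vectors with different symbol frequencies
data = torch.cat([
    torch.distributions.Multinomial(20, torch.tensor([0.7, 0.1, 0.1, 0.1])).sample((25,)),
    torch.distributions.Multinomial(20, torch.tensor([0.1, 0.1, 0.1, 0.7])).sample((25,)),
])
K, alpha = 4, 1.0
phi = torch.softmax(torch.randn(50, K), dim=-1)
for _ in range(30):
    lam = update_theta(data, phi)
    phi = update_assignments(phi, expected_log_lik(data, lam), alpha)
print(phi.argmax(-1))
```

Updating one q(z_n) at a time and patching `expected_counts` immediately is what makes each update condition on the latest beliefs about all the other assignments, mirroring the sweep of a collapsed Gibbs sampler.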

So my question is: in Pyro, how can we use pyro.sample to condition on past events (the other cluster assignments), with the exception of the current one?