Sampling a global latent variable conditioned on a single data point

Dear all,

Suppose we are given a model

\begin{align}p(\theta)\cdot \prod_{i=1}^n p(\mathbf{x}_i|\mathbf{z}_i, \theta)p(\mathbf{z}_i|\theta),\end{align}

where \mathbf{z}_i and \theta are local and global latent variables, respectively.
I am wondering if we can efficiently sample (with MCMC) from the posterior p(\theta, \mathbf{z}_i|\mathbf{x}_i) conditioned on a single point \mathbf{x}_i, rather than from p(\theta, \mathbf{z}_i|\mathbf{x}_{1:n}) conditioned on the whole dataset \mathbf{x}_{1:n}.
In short, the question is: is there any simple way to obtain m samples per data point, i.e., \left[\{(\theta_i^{(j)}, \mathbf{z}_i^{(j)})\}_{i=1}^n\right]_{j=1}^m with (\theta_i^{(j)}, \mathbf{z}_i^{(j)})\sim p(\theta, \mathbf{z}_i|\mathbf{x}_{i})?
(That is, the returned \theta samples alone would have size m \times n \times d, with d being \theta's dimensionality.)

Let’s say the observation site is defined by pyro.plate. If the global latent were absent, sampling would be straightforward, and we wouldn’t need to explicitly write a for-loop (this is what I meant by “efficiently” above). Would the same be possible even with a global latent (without defining another model that samples a single data point)? If not, what would be the Pyro way to implement this?
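
For concreteness, here is a minimal sketch of the kind of model I have in mind (the Gaussian distributions and site names are just placeholders):

```python
import torch
import pyro
import pyro.distributions as dist

def model(x):
    # x: tensor of n scalar observations
    theta = pyro.sample("theta", dist.Normal(0., 1.))     # global latent
    with pyro.plate("data", x.shape[0]):                  # observation plate
        z = pyro.sample("z", dist.Normal(theta, 1.))      # local latents z_i
        pyro.sample("obs", dist.Normal(z, 1.), obs=x)     # observation site
```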

Thank you for reading the question. I would appreciate any feedback!

these are n distinct posteriors, so this would require n distinct MCMC runs. are you sure you want to do this?

Thanks for the quick response. Yes, this is unavoidable for my purpose. More specifically, I am trying to implement the score function (of the marginal likelihood) for latent variable models:

\begin{align}\nabla_{\mathbf{x}} \log p(\mathbf{x}) = \mathbb{E}_{\mathbf{z},\theta|\mathbf{x}} \left[\nabla_{\mathbf{x}} \log p(\mathbf{x}|\mathbf{z},\theta)\right].\end{align}

I need to evaluate the score function on a test dataset, and thus need to evaluate the posterior expectation n times.
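
Concretely, given m posterior samples per point, the estimator I want is just an average of likelihood gradients. Here is a sketch, with a hypothetical log_lik(x, z, theta) standing in for \log p(\mathbf{x}|\mathbf{z},\theta):

```python
import torch

def score_estimate(x, z_samples, theta_samples, log_lik):
    # Monte Carlo estimate of grad_x log p(x): average grad_x log p(x|z, theta)
    # over posterior samples (z, theta) ~ p(z, theta | x).
    x = x.detach().requires_grad_(True)
    log_liks = [log_lik(x, z, th) for z, th in zip(z_samples, theta_samples)]
    return torch.autograd.grad(torch.stack(log_liks).mean(), x)[0]
```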

your best bet would probably be to use numpyro + HMC + vmap so that you can do HMC in parallel for different mini-batches of {x_i}
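
e.g. something along these lines (an untested sketch; single_model is a per-data-point version of the model above, again with placeholder distributions):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def single_model(x_i):
    # same model as before, but conditioned on a single data point x_i
    theta = numpyro.sample("theta", dist.Normal(0., 1.))
    z = numpyro.sample("z", dist.Normal(theta, 1.))
    numpyro.sample("obs", dist.Normal(z, 1.), obs=x_i)

def run_hmc(rng_key, x_i):
    mcmc = MCMC(NUTS(single_model), num_warmup=500, num_samples=1000,
                progress_bar=False)  # progress bar must be off under vmap
    mcmc.run(rng_key, x_i)
    return mcmc.get_samples()  # dict with "theta" and "z", each of shape (m,)

xs = jnp.linspace(-2., 2., 8)  # toy data, n = 8
keys = jax.random.split(jax.random.PRNGKey(0), len(xs))
samples = jax.vmap(run_hmc)(keys, xs)  # each entry now has shape (n, m)
```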

vmap sounds like a neat solution!
My use of the posterior expectation is kind of niche, so I don’t expect there’s a neater solution than this (though it would be great if anyone could point out another way).
Thank you!