I need to modify my variational inference code to optimize a slightly different objective than ELBO.

Typically, we want to maximize the marginal log likelihood of the data. The log probability of an individual data point x can be written as

\log p(x) = \mathrm{KL}\big(q(z \mid x) \,\|\, p(z \mid x)\big) + \mathbb{E}_{q(z \mid x)}\big[\log p(x, z) - \log q(z \mid x)\big].
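For concreteness, the decomposition can be checked numerically on a toy conjugate model where every term is available in closed form. The model below is my own illustrative choice, not my actual setting: prior p(z) = N(0, 1), likelihood p(x | z) = N(z, 1), so the marginal is p(x) = N(0, 2) and the true posterior is p(z | x) = N(x/2, 1/2).

```python
import math

# Toy conjugate model (illustrative assumption, not my real model):
#   p(z) = N(0, 1),  p(x | z) = N(z, 1)
#   => p(x) = N(0, 2),  p(z | x) = N(x/2, 1/2)
# Guide: q(z | x) = N(m, s^2) with freely chosen m, s.

def log_px(x):
    # Exact marginal log likelihood: log N(x; 0, 2)
    return -0.5 * math.log(2 * math.pi * 2.0) - x**2 / 4.0

def kl_q_posterior(x, m, s):
    # KL( N(m, s^2) || N(x/2, 1/2) ), standard Gaussian KL formula
    mu_p, var_p = x / 2.0, 0.5
    return (math.log(math.sqrt(var_p) / s)
            + (s**2 + (m - mu_p)**2) / (2.0 * var_p) - 0.5)

def elbo(x, m, s):
    # E_q[log p(x, z) - log q(z | x)], each expectation in closed form
    e_log_prior = -0.5 * math.log(2 * math.pi) - (m**2 + s**2) / 2.0
    e_log_lik = -0.5 * math.log(2 * math.pi) - ((x - m)**2 + s**2) / 2.0
    entropy_q = 0.5 * math.log(2 * math.pi * s**2) + 0.5
    return e_log_prior + e_log_lik + entropy_q

# The identity log p(x) = KL + ELBO holds for any choice of m, s:
x, m, s = 1.0, 0.3, 0.8
print(log_px(x), kl_q_posterior(x, m, s) + elbo(x, m, s))
```

For any guide parameters the two printed values agree, and since the KL term is nonnegative the ELBO never exceeds log p(x).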

Ordinarily, we cannot evaluate \log p(x) exactly because the KL divergence between the approximate posterior and the true posterior is intractable. Since that KL term is nonnegative, the second term, the ELBO, is a lower bound on \log p(x), so we maximize the ELBO instead.

In my setting, the true posterior is available in closed form for a subset of the data points. For those data points only, I want to use the exact log probability \log p(x) as the objective rather than the ELBO. Is there a straightforward way to modify Pyro’s ELBO classes to do this, or am I better off writing the training loop from scratch?
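To make the objective I have in mind concrete, here is a pure-Python sketch on the same kind of toy conjugate model (all names and the model itself are my own illustration, not Pyro API or my real code): exact negative log likelihood for the tractable points, negative ELBO for the rest.

```python
import math

# Illustrative conjugate model: p(z) = N(0, 1), p(x | z) = N(z, 1),
# so p(x) = N(0, 2) stands in for the "tractable subset" where the
# exact marginal (equivalently, the true posterior) is available.

def exact_log_px(x):
    # Exact objective for tractable points: log N(x; 0, 2)
    return -0.5 * math.log(2 * math.pi * 2.0) - x**2 / 4.0

def elbo(x, m, s):
    # Closed-form ELBO for a Gaussian guide q(z | x) = N(m, s^2)
    e_log_prior = -0.5 * math.log(2 * math.pi) - (m**2 + s**2) / 2.0
    e_log_lik = -0.5 * math.log(2 * math.pi) - ((x - m)**2 + s**2) / 2.0
    entropy_q = 0.5 * math.log(2 * math.pi * s**2) + 0.5
    return e_log_prior + e_log_lik + entropy_q

def mixed_loss(xs, tractable, m, s):
    # Minimize exact -log p(x) where the posterior is known,
    # and the usual negative ELBO everywhere else.
    total = 0.0
    for x, ok in zip(xs, tractable):
        total -= exact_log_px(x) if ok else elbo(x, m, s)
    return total
```

Since the ELBO lower-bounds \log p(x), swapping in the exact term can only decrease the loss for those points; the question is whether this per-datapoint switch fits into Pyro's ELBO machinery or is easier to express as a hand-written loss.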