I am using the `pyro.infer.csis` class to amortize the cost of inference by learning a neural guide that provides proposal distributions to be weighted in an importance sampling procedure. However, when inspecting how importance weights are computed in `pyro.infer.importance`, doesn't setting `log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()` mean that the sequential part is missing? I mean, shouldn't a weight be computed for each latent variable in an iterative way (i.e. progressively adding more latent variables to be considered when computing the weight)?
To address this, you would indeed need to iteratively compute weights as you add more latent variables. This sequential approach ensures that the importance weights accurately reflect the contribution of each latent variable to the overall probability distribution.
Exactly. I ended up looping over all latent variables (i.e. after running the model) and iteratively setting `mask = False` on the sites that shouldn't contribute to `log_prob_sum()`.
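For anyone landing here later, this is a minimal sketch of that loop in plain Python. It does not use Pyro: the two ordered dicts stand in for the per-site log-probabilities that a model trace and a guide trace would hold, and the site names (`z1`, `z2`, `y`), the distributions, and the helpers `masked_sum` and `sequential_log_weights` are all made up for illustration.

```python
import math
from collections import OrderedDict


def normal_log_prob(x, mean, std):
    """Log density of a univariate normal distribution."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


# Hypothetical values drawn from the guide for two latent sites, in sampling order.
values = OrderedDict([("z1", 0.3), ("z2", -0.5)])
observed_y = 0.1

# Stand-in for a model trace: per-site log-probabilities (y is the observed site).
model_log_probs = OrderedDict([
    ("z1", normal_log_prob(values["z1"], 0.0, 1.0)),
    ("z2", normal_log_prob(values["z2"], values["z1"], 1.0)),
    ("y", normal_log_prob(observed_y, values["z2"], 0.5)),
])

# Stand-in for a guide trace: the guide only scores the latent sites.
guide_log_probs = OrderedDict([
    ("z1", normal_log_prob(values["z1"], 0.2, 1.5)),
    ("z2", normal_log_prob(values["z2"], 0.0, 1.5)),
])


def masked_sum(log_probs, mask):
    """Sum the log-probabilities of the sites whose mask entry is True."""
    return sum(lp for name, lp in log_probs.items() if mask.get(name, False))


def sequential_log_weights(model_lp, guide_lp, latent_order):
    """Return one partial log weight per latent site.

    Step k keeps the first k + 1 latent sites and masks out every other site.
    """
    weights = []
    for k in range(len(latent_order)):
        mask = {name: True for name in latent_order}
        for later_site in latent_order[k + 1:]:
            mask[later_site] = False  # this site should not contribute yet
        weights.append(masked_sum(model_lp, mask) - masked_sum(guide_lp, mask))
    return weights


latents = list(values)
partial = sequential_log_weights(model_log_probs, guide_log_probs, latents)

# The single joint weight, as in the expression quoted in the question.
joint = sum(model_log_probs.values()) - sum(guide_log_probs.values())
```

In this sketch the last partial weight plus the log-likelihood of the observed site equals the joint weight, so the loop exposes the intermediate weights without changing the final one.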