How do I join two VAEs in Pyro to create a product of experts (PoE)? Is there an example?
Well, if your product-of-experts VAE consists of a collection of encoders that each output a diagonal Normal over a shared latent state, say
encoders: List[Callable[[torch.Tensor], dist.Normal]]
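then you can fuse the Normals manually. A product of Gaussian densities is, up to normalization, again Gaussian: its precision (inverse variance) is the sum of the experts' precisions, and its mean is the precision-weighted average of their means. So you can draw a z sample from the product of experts as follows: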
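from typing import List
import pyro.distributions as dist

expert_opinions: List[dist.Normal] = [e(x) for e in encoders]
# Precisions (inverse variances) of the Gaussian experts add up.
precision = sum(n.scale ** (-2) for n in expert_opinions)
# The fused mean is the precision-weighted average of the expert means.
mean = sum(n.loc / n.scale ** 2 for n in expert_opinions) / precision
product_opinion = dist.Normal(mean, precision.rsqrt())
z = product_opinion.rsample()  # reparameterized, so gradients reach the encoders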