Define a custom ELBO function in Pyro

Hello, I am trying to implement my own custom ELBO from a likelihood term, a prior term, and a KL divergence term. How do I define these in my custom ELBO?
Something like this: elbo = likelihood + prior - beta * kl_divergence

Assume a VAE defined by a model and a guide function, with data x as input and y as labels, where:
likelihood is BCE(reconstructed_x, x)
prior is log_standard_categorical(y)
KL divergence is the KLD between two Dirichlet distributions.
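To make the three terms concrete, here is a minimal sketch in plain PyTorch (not Pyro-specific). The function name `custom_elbo` and its arguments are hypothetical, and it assumes `y` is one-hot, the label prior is uniform, and `recon_x` holds Bernoulli probabilities. Note that BCE is a negative log-likelihood, so it enters the ELBO with a minus sign.

```python
import math

import torch
import torch.nn.functional as F
from torch.distributions import Dirichlet, kl_divergence


def custom_elbo(recon_x, x, y, q_conc, p_conc, beta=1.0):
    """Hypothetical helper: mean over the batch of likelihood + prior - beta * KL."""
    # Bernoulli log-likelihood of x: the negative BCE, summed over features.
    log_lik = -F.binary_cross_entropy(recon_x, x, reduction="none").sum(dim=-1)
    # Log-probability of one-hot y under a uniform categorical prior (assumption).
    num_classes = y.shape[-1]
    log_prior = (y * math.log(1.0 / num_classes)).sum(dim=-1)
    # Closed-form KL(q || p) between two Dirichlet distributions.
    kl = kl_divergence(Dirichlet(q_conc), Dirichlet(p_conc))
    return (log_lik + log_prior - beta * kl).mean()
```

To train with it you would minimize `-custom_elbo(...)`, with `q_conc` produced by the encoder and `p_conc` set to the prior concentration.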

This tutorial explains how loss functions can be customized. In particular, your case looks similar to the Beta VAE example.
