I’m still new to Pyro and I’m trying to get my head around it. At the moment I would like to implement a simple beta-VAE. I know how I would do it in PyTorch, but I’m missing something when it comes to reproducing it in Pyro (perhaps it is better to just use vanilla PyTorch?).
For instance, I want to combine an MSE reconstruction loss with a KL term weighted by beta. In PyTorch I would do it like this:
    mse = F.mse_loss(x_out, target, reduction="sum")  # size_average=False is deprecated
    kl = -0.5 * torch.sum(1 + z_logvar - (z_mean**2) - torch.exp(z_logvar))
    loss = ((alpha * mse) + (beta * kl)) / x_out.size(0)
So the loss is a weighted sum of the MSE and the KL term, averaged over the batch.
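To make that concrete, here is a minimal self-contained sketch of the loss in plain PyTorch. The function name `beta_vae_loss` and the default values for `alpha` and `beta` are placeholders I made up for illustration; it assumes a diagonal Gaussian posterior and a standard normal prior, which is what the closed-form KL expression implies.

```python
import torch
import torch.nn.functional as F


def beta_vae_loss(x_out, target, z_mean, z_logvar, alpha=1.0, beta=4.0):
    """Weighted sum of reconstruction error and KL, averaged over the batch.

    alpha and beta are hypothetical default weights, not recommended values.
    """
    # Squared error summed over every element of the batch
    mse = F.mse_loss(x_out, target, reduction="sum")
    # Closed-form KL( N(z_mean, exp(z_logvar)) || N(0, I) ), summed over batch and latent dims
    kl = -0.5 * torch.sum(1 + z_logvar - z_mean ** 2 - torch.exp(z_logvar))
    # Divide by the batch size (first dimension of x_out)
    return (alpha * mse + beta * kl) / x_out.size(0)
```

It would be called once per batch, e.g. `loss = beta_vae_loss(x_out, target, z_mean, z_logvar, alpha=1.0, beta=4.0)`, followed by `loss.backward()`.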
What do you think is the best way of approaching this?