How to implement a beta-VAE with an MSE loss


I’m still new to Pyro and trying to get my head around it. I would like to implement a simple beta-VAE. I know how I would do it in PyTorch, but I’m missing something when it comes to Pyro (perhaps it is better to just use vanilla PyTorch?)

For instance, I want to combine an MSE loss with a KL term, weighting them by alpha and beta. In PyTorch I would do it like this:

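import torch
import torch.nn.functional as F
mse = F.mse_loss(x_out, target, reduction="sum")  # summed squared error (size_average=False is deprecated)
kl = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())  # closed-form KL(q(z|x) || N(0, I))
loss = (alpha * mse + beta * kl) / x_out.size(0)  # weighted sum, averaged over the batch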

So the loss is a weighted sum of the MSE and KL terms, divided by the batch size.
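A self-contained version of that loss, as a sketch: it assumes the encoder outputs z_mean and z_logvar of a diagonal Gaussian posterior and that the prior is N(0, I). The helper name beta_vae_loss and the random stand-in tensors are hypothetical. It uses reduction="sum" in place of the deprecated size_average=False and cross-checks the closed-form KL against torch.distributions.kl_divergence.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence


def beta_vae_loss(x_out, target, z_mean, z_logvar, alpha=1.0, beta=4.0):
    """Hypothetical helper: (alpha * MSE + beta * KL) averaged over the batch."""
    # Summed squared error over all elements.
    mse = F.mse_loss(x_out, target, reduction="sum")
    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over batch and latent dims.
    kl = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp())
    return (alpha * mse + beta * kl) / x_out.size(0), mse, kl


torch.manual_seed(0)
batch, x_dim, z_dim = 8, 10, 3
# Random stand-ins for decoder output, targets and encoder outputs.
x_out, target = torch.randn(batch, x_dim), torch.randn(batch, x_dim)
z_mean, z_logvar = torch.randn(batch, z_dim), torch.randn(batch, z_dim)

loss, mse, kl = beta_vae_loss(x_out, target, z_mean, z_logvar)

# Cross-check the closed-form KL against torch.distributions.
q = Normal(z_mean, (0.5 * z_logvar).exp())  # std = exp(logvar / 2)
p = Normal(torch.zeros_like(z_mean), torch.ones_like(z_mean))
kl_ref = kl_divergence(q, p).sum()
print(torch.allclose(kl, kl_ref, atol=1e-4))
```

The closed form follows from KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2) per latent dimension, which is exactly the expression inside the sum with the sign pulled out.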

What do you think is the best way of approaching this?

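You can use the scale poutine (pyro.poutine.scale) to apply different scale factors to different terms in the ELBO as needed. For a beta-VAE, wrap the latent sample site in the same scale in both the model and the guide: the ELBO contains log p(z) - log q(z|x), whose expectation under the guide is minus the KL, so that term gets multiplied by beta: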

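import pyro

def model(data, beta=0.5):
    # beta scales log p(z) in the model
    with pyro.poutine.scale(scale=beta):
        pyro.sample("z", ...)
    pyro.sample("data", ..., obs=data)

def guide(data, beta=0.5):
    # ...and log q(z|x) in the guide, so the KL term is scaled by beta
    with pyro.poutine.scale(scale=beta):
        pyro.sample("z", ...)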

To add an MSE term to the ELBO, see here.