Denoising VAE

Hi, I’m using the latest pyro and tutorials.

Elsewhere I have a β-VAE PyTorch (Lightning) implementation that trains on audio waveforms and denoises them by discarding information during reconstruction.

The training step is as follows:

    def training_step(self, x, batch_idx):
        x_encoded = self.encoder(x)
        mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)
        std = torch.exp(log_var / 2)
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()
        x_hat = self.decoder(z)
        recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
        kl = self.kl_divergence(z, mu, std)
        elbo = (kl - recon_loss).mean()
        # log metrics (Lightning)
        self.log_dict({
            'elbo': elbo,
            'kl': kl.mean(),
            'recon_loss': recon_loss.mean(),
        })
        return elbo
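For completeness, the two helpers are roughly the ones from the standard Lightning VAE write-up; a sketch under my assumptions (the reduction over `dim=-1` assumes batched waveforms of shape `(batch, samples)`):

```python
import torch

def gaussian_likelihood(x_hat, log_scale, x):
    """log p(x|z) under N(x_hat, exp(log_scale)), summed over waveform samples."""
    scale = torch.exp(log_scale)
    dist = torch.distributions.Normal(x_hat, scale)
    return dist.log_prob(x).sum(dim=-1)  # shape: (batch,)

def kl_divergence(z, mu, std):
    """Monte Carlo estimate of KL(q(z|x) || p(z)) with a standard-normal prior."""
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)
    return (q.log_prob(z) - p.log_prob(z)).sum(dim=-1)  # shape: (batch,)
```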

So, as you can see, x_hat is used directly in the loss; it is not sampled from a learned observation distribution.

I am learning pyro and trying to hack the MNIST SVI examples to accept my data and do the same training.

In a typical example, such as the VAE with Flow prior tutorial, I find a decoder with the forward function defined as:

    def forward(self, z: Tensor):
        phi = self.hyper(z)
        rho = torch.sigmoid(phi)
        return pyro.distributions.Bernoulli(rho).to_event(1)

Similarly, the model in the VAE example is:

    # define the model p(x|z)p(z)
    def model(self, x):
        pyro.module("decoder", self.decoder)
        with pyro.plate("data", x.shape[0]):
            z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
            z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
            z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
            loc_img = self.decoder(z)
            pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 3250))

This works for MNIST, per the related discussions, but it isn't what I am looking for, since I do not sample the observation from a distribution. I need something like…

       pyro.register_samples_as_if_sampled_from_dist("obs", loc_img)

Or some way to make the observation model a Normal distribution centered on the deterministic decoder output.

Any advice on how to port the pytorch training step to pyro would be appreciated.