Hi, I’m using the latest Pyro and its tutorials.
Elsewhere I have a BVAE PyTorch implementation that trains on audio waveforms and denoises them by discarding information during reconstruction.
The training step is as follows:
def training_step(self, x, batch_idx):
    # encode x and compute the variational parameters
    x_encoded = self.encoder(x)
    mu, log_var = self.fc_mu(x_encoded), self.fc_var(x_encoded)
    std = torch.exp(log_var / 2)

    # sample z from q(z|x) via the reparameterization trick
    q = torch.distributions.Normal(mu, std)
    z = q.rsample()

    # decode to a deterministic reconstruction
    x_hat = self.decoder(z)

    # Gaussian log-likelihood of x given x_hat, plus the KL term
    recon_loss = self.gaussian_likelihood(x_hat, self.log_scale, x)
    kl = self.kl_divergence(z, mu, std)

    # negative ELBO, the quantity to minimize
    elbo = (kl - recon_loss).mean()

    self.log_dict({
        'elbo': elbo,
        'kl': kl.mean(),
        'recon_loss': recon_loss.mean(),
    })
    return elbo
So, as you can see, x_hat is never sampled from a learned distribution when calculating the loss; the decoder output is used directly as the mean of a Gaussian likelihood.
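For reference, gaussian_likelihood just scores x under a Normal centered on the deterministic x_hat — roughly this (the exact reduction over the waveform dimensions may differ in my code):

def gaussian_likelihood(self, x_hat, log_scale, x):
    # density of x under N(x_hat, exp(log_scale)); x_hat itself is never sampled
    scale = torch.exp(log_scale)
    dist = torch.distributions.Normal(x_hat, scale)
    # sum the log-density over the waveform dimension(s)
    return dist.log_prob(x).sum(dim=-1)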
I am learning Pyro and trying to adapt the MNIST SVI examples to accept my data and train the same way.
In a typical example, such as the VAE with Flow prior tutorial, I find a decoder with the forward function defined as:
def forward(self, z: Tensor):
    phi = self.hyper(z)
    rho = torch.sigmoid(phi)  # Bernoulli probabilities per pixel
    return pyro.distributions.Bernoulli(rho).to_event(1)
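For a real-valued waveform I imagine the equivalent decoder would return a Normal instead — a sketch, where self.hyper is the same decoder network and self.log_scale is an assumed learnable parameter like the one in my Lightning model:

def forward(self, z: Tensor):
    loc = self.hyper(z)                # deterministic reconstruction, like my x_hat
    scale = torch.exp(self.log_scale)  # assumed learned observation noise
    return pyro.distributions.Normal(loc, scale).to_event(1)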
Similarly, the model in the VAE example is defined as:
# define the model p(x|z)p(z)
def model(self, x):
    pyro.module("decoder", self.decoder)
    with pyro.plate("data", x.shape[0]):
        # standard Normal prior over the latent code
        z_loc = x.new_zeros(torch.Size((x.shape[0], self.z_dim)))
        z_scale = x.new_ones(torch.Size((x.shape[0], self.z_dim)))
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        # decode the latent code into the observation's parameters
        loc_img = self.decoder(z)
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x.reshape(-1, 3250))
This works for MNIST, as per the related discussions, but it isn’t what I am looking for, since I do not sample the observation from a distribution. I need something like…
pyro.register_samples_as_if_sampled_from_dist("obs", loc_img)
Or some other way to express the observation model as a Normal distribution centered on the decoder output.
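Concretely, I imagine the observation site would become something like the sketch below, where the positive obs_scale is a pyro.param I’m introducing as a stand-in for my log_scale (I’m not sure this is the idiomatic way):

# assumes: import torch, pyro, pyro.distributions as dist,
# and from pyro.distributions import constraints
def model(self, x):
    pyro.module("decoder", self.decoder)
    # learnable observation noise, playing the role of my log_scale
    scale = pyro.param("obs_scale", torch.tensor(1.0),
                       constraint=constraints.positive)
    with pyro.plate("data", x.shape[0]):
        z_loc = x.new_zeros((x.shape[0], self.z_dim))
        z_scale = x.new_ones((x.shape[0], self.z_dim))
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1))
        loc_img = self.decoder(z)
        # with obs= given, Pyro never samples here; it scores x instead
        pyro.sample("obs", dist.Normal(loc_img, scale).to_event(1),
                    obs=x.reshape(-1, 3250))

If I understand Trace_ELBO correctly, because obs= is supplied the "obs" site is never actually sampled: the ELBO just picks up log p(x|z), which should reduce to exactly my recon_loss term, while the KL is handled through the "latent" site. Does that sound right?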
Any advice on how to port the pytorch training step to pyro would be appreciated.
Nick