VAE: Turn Bayesian Encoder into a normal Encoder?

Hey everyone,

I am following the Pyro VAE tutorial, and I’d like to replace the Bayesian encoder with a normal encoder, i.e. I don’t want to learn variances for all encoder weights, only for the prediction of z, which is given by z_loc and z_scale.

However, when I comment out the line pyro.module("encoder", self.encoder), the model does not seem to learn anything.
Does someone know how to achieve that?

I am not exactly sure what you are asking, but maybe this is helpful.

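  1. pyro.module registers the encoder’s parameters with Pyro’s param store so that the optimizer can learn them. Without it, the encoder weights stay at their random initial values.
  2. The encoder in the VAE is a normal (amortized) VAE encoder: there is nothing Bayesian about its weights beyond what is Bayesian in any VAE. In particular, there are no priors over the weights of the neural network; they are point estimates. So it should already be doing exactly what you want: mapping each x to z_loc and z_scale, which parameterize the distribution over z.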

Thanks a lot for the reply! When I register the encoder and run pyro.get_param_store().keys() during training, I get:

dict_keys(['encoder$$$fc1.weight', 'encoder$$$fc1.bias', 'encoder$$$fc21.weight', 'encoder$$$fc21.bias', 'encoder$$$fc22.weight', 'encoder$$$fc22.bias', 'decoder$$$fc1.weight', 'decoder$$$fc1.bias', 'decoder$$$fc21.weight', 'decoder$$$fc21.bias'])

Hence, the weights in the encoder don’t actually seem to be Bayesian (otherwise there would be a loc and a scale for each parameter).

Is it possible to make the neural network weights Bayesian, i.e. to estimate a loc and a scale for encoder$$$fc1.weight, encoder$$$fc1.bias, ... with a (small) hack?

Not really. As in the tutorial, you have p(x, z) and you want to obtain the posterior p(z|x), which you approximate with q_\phi(z|x), the encoder. The encoder is not Bayesian by construction: it is the variational approximation to the posterior, and its weights \phi are the variational parameters, which are optimized as point estimates rather than given a distribution of their own. However, what you could do is impose a distribution over the weights and have \phi parameterize that distribution (e.g. a loc and a scale per weight). This should be doable by writing the layers of the neural network by hand, but I am not sure there is a simple way to do it.


Thanks!