Hi,
Suppose I create a normalizing flow using:
import pyro.distributions as dist
targetdist = dist.TransformedDistribution(basedist, realnvp)
where basedist is a prespecified base distribution and realnvp is a list of transforms that includes multiple AffineCoupling layers, BatchNorm layers and Permute layers. What is the appropriate way to put the BatchNorm layers into .eval() mode when sampling with targetdist.sample()?
Thanks,