How to set evaluation mode for a normalizing flow with BatchNorm layers

Hi,

Suppose I create a normalizing flow using

```python
import pyro.distributions as dist
targetdist = dist.TransformedDistribution(basedist, realnvp)
```

where basedist is a prespecified base distribution and realnvp is a list of transforms including multiple AffineCoupling, BatchNorm and Permute layers. What is the appropriate way to put the BatchNorm layers into .eval() mode when sampling with targetdist.sample()?

Thanks,

I think you should be able to do the following:

```python
for t in targetdist.transforms:
    if hasattr(t, 'eval'):
        t.eval()
```

Sure, thanks!