Putting a normalizing-flow class into a `random_module` prior

My planar normalizing-flow module is as follows:

class PlanFlow(nn.Module):
    """Stack of ``n`` planar-flow transforms applied to a standard-normal base.

    ``forward`` builds a ``dist.TransformedDistribution`` over a tensor shaped
    like ``params``, so it can be used as the value for one parameter name in
    a ``pyro.random_module`` prior/guide dictionary.
    """

    def __init__(self, params, n=1):
        """Build ``n`` independent planar-flow layers.

        Args:
            params: tensor whose second dimension (``size()[1]``) sets the
                flow's input size; presumably a 2-D Linear weight of shape
                (out, in) -- TODO confirm.
            n: number of stacked planar transforms.
        """
        super(PlanFlow, self).__init__()
        plf_dist = dist.PlanarFlow(int(params.size()[1]))
        # deepcopy so every layer gets its own, independently trained parameters
        self.layer = nn.ModuleList([copy.deepcopy(plf_dist) for _ in range(n)])

    def forward(self, params):
        """Return the flow distribution shaped like ``params``.

        ``params`` is used only for its shape, dtype and device, via
        ``zeros_like`` / ``ones_like``.
        """
        base_dist = dist.Normal(torch.zeros_like(params), torch.ones_like(params))
        plf_dist = dist.TransformedDistribution(base_dist, list(self.layer))
        # NOTE(review): to_event(1) is meant to make log_prob cover the whole
        # weight matrix; assumes 2-D params -- confirm event shape matches the prior.
        return plf_dist.to_event(1)

I want to use this planar-flow distribution as an entry in the prior dictionary that the guide passes to `pyro.random_module`.

class BertBNN(nn.Module):
    """Bayesian wrapper that lifts ``self.net.encoder.dense.weight`` to a random variable.

    The model places ``standard_normal_prior`` on that weight. The guide
    replaces it with a learned planar-flow distribution (``PlanFlow``).
    """

    def __init__(self, config, num_labels, flow=1):
        """
        Args:
            config: configuration forwarded to ``embedding`` and ``model``.
            num_labels: number of output labels for the wrapped network.
            flow: number of planar-flow layers used in the guide.
        """
        super(BertBNN, self).__init__()
        self.embedding = embedding(config)
        # `model` here is the module-level factory, not the method below
        self.net = model(config, num_labels)
        self.sigmoid = nn.Sigmoid()
        self.flow = flow
        # was `encoder.dnse.weight` (typo), which raised AttributeError
        self.netFlow = PlanFlow(params=self.net.encoder.dense.weight, n=self.flow)

    def model(self, x, y, **kwargs):
        """Pyro model: sample the lifted network and observe ``y`` as Bernoulli.

        Returns the logits produced by the sampled network.
        """
        # NOTE(review): previously read an undefined name `attention_mask`;
        # it is now taken from kwargs (None if absent) -- confirm the callers.
        attention_mask = kwargs.get("attention_mask")

        netflow = standard_normal_prior(self.net.encoder.dense.weight)
        prior = {
            'encoder.dense.weight': netflow,
        }

        lifted_module = pyro.random_module('module', self.net, prior)
        lifted_reg_model = lifted_module()

        pyro.module("embedding", self.embedding)

        with pyro.plate('observe_data'):
            embedding_output = self.embedding(x)
            logits = lifted_reg_model(embedding_output, attention_mask)
            p_hat = self.sigmoid(logits)
            pyro.sample("obs", Bernoulli(probs=p_hat), obs=y)
        return logits

    def guide(self, x, y, **kwargs):
        """Pyro guide: sample the lifted weight from the planar-flow distribution."""
        # Register the flow's parameters with Pyro's param store; without this
        # SVI never collects or optimizes them.
        pyro.module("netFlow", self.netFlow)
        netflow = self.netFlow(self.net.encoder.dense.weight)

        prior = {
            'encoder.dense.weight': netflow,
        }

        lifted_module = pyro.random_module("module", self.net, prior)
        return lifted_module()

If I train the network this way, it seems the parameters are not being loaded onto the GPU for training. I know how much memory the plain neural network uses, and the normalizing-flow version uses far too little.

So I am wondering how to plug the normalizing-flow framework into the `random_module` prior.

`random_module` takes Pyro distributions, so you would have to implement your flow as a torch distribution object. Take a look at these flows for an example.

1 Like