Hi all,
I made a very simple dataset with two inputs, z and x, and a binary class label (0 or 1). To keep the problem easy, x is always equal to 1, while z is sampled from a standard normal distribution N(0, 1); if z < 0 the label is class 1, and if z > 0 it is class 0.
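Concretely, the data can be generated like this (a minimal sketch; the sample count `N` is my own placeholder):

```python
import torch

N = 1000
z = torch.randn(N)                     # z ~ Normal(0, 1)
x = torch.ones(N)                      # x is always 1
inputs = torch.stack([z, x], dim=1)    # shape (N, 2), so nx = 2 below
labels = (z < 0).long()                # class 1 if z < 0, else class 0
```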
The model and guide I used are below (adapted from an IBM example):
```python
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, softplus

import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# nx, nh, ny (layer sizes) and device are defined elsewhere;
# with the (z, x) inputs above, nx = 2 and ny = 2.

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.l1 = torch.nn.Linear(nx, nh)   # input -> hidden
        # self.l1a = torch.nn.Linear(nh, nh)
        self.l2 = torch.nn.Linear(nh, ny)   # hidden -> output logits
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        h = self.relu(self.l1(torch.Tensor(x).view((-1, nx))))
        # h = self.relu(self.l1a(h.view((-1, nh))))
        yhat = self.l2(h)
        return yhat

mlp = MLP().to(device)

# Model: standard normal priors over all weights and biases
def normal(*shape):
    loc = torch.zeros(*shape).to(device)
    scale = torch.ones(*shape).to(device)
    return Normal(loc, scale)

def model(imgs, lbls):
    priors = {
        'l1.weight': normal(nh, nx), 'l1.bias': normal(nh),
        'l2.weight': normal(ny, nh), 'l2.bias': normal(ny)}
    lifted_module = pyro.random_module("mlp", mlp, priors)
    lifted_reg_model = lifted_module()
    lhat = log_softmax(lifted_reg_model(imgs), dim=-1)
    pyro.sample("obs", Categorical(logits=lhat), obs=lbls)

# Inference guide: diagonal normal posteriors, softplus keeps the scales positive
def vnormal(name, *shape):
    loc = pyro.param(name + "m", torch.randn(*shape, requires_grad=True, device=device))
    scale = pyro.param(name + "s", torch.randn(*shape, requires_grad=True, device=device))
    return Normal(loc, softplus(scale))

def guide(imgs, lbls):
    dists = {
        'l1.weight': vnormal("W1", nh, nx), 'l1.bias': vnormal("b1", nh),
        'l2.weight': vnormal("W2", ny, nh), 'l2.bias': vnormal("b2", ny)}
    lifted_module = pyro.random_module("mlp", mlp, dists)
    return lifted_module()

inference = SVI(model, guide, Adam({"lr": 0.001}), loss=Trace_ELBO())
```
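The training loop itself looks roughly like this (a sketch; the epoch count is arbitrary, and `inputs`/`labels` are the tensors from the data-generation snippet above):

```python
batch_size = 20
num_epochs = 50

for epoch in range(num_epochs):
    epoch_loss = 0.0
    perm = torch.randperm(len(inputs))           # reshuffle each epoch
    for i in range(0, len(inputs), batch_size):
        idx = perm[i:i + batch_size]
        epoch_loss += inference.step(inputs[idx], labels[idx])
    print(f"epoch {epoch}: avg ELBO loss {epoch_loss / len(inputs):.4f}")
```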
Testing this on minibatches of 20, the classification accuracy seems to get worse as the number of epochs goes up.
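By "performance" I mean classification accuracy, measured roughly like this (a sketch: sample a few networks from the guide and average their predictive log-probabilities; the sample count is arbitrary):

```python
def predict(x, n_samples=10):
    nets = [guide(None, None) for _ in range(n_samples)]   # posterior samples
    log_probs = torch.stack([log_softmax(net(x), dim=-1) for net in nets])
    return log_probs.mean(dim=0).argmax(dim=-1)            # averaged prediction

accuracy = (predict(inputs) == labels).float().mean()
```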
Could anybody explain why this is?