Hi all,
I’m attempting to fit a simple factor analysis model using Pyro:
import torch
import pyro
import pyro.distributions as dist

def fa_model(data=None, args=None, batch_size=None):
    K, P, N = args['K'], args['P'], args['N']
    w = pyro.sample("w", dist.Normal(torch.randn(size=[K, P]), 1.0).to_event(2))
    z = pyro.sample("z", dist.Normal(torch.randn(size=[N, K]), 1.0).to_event(2))
    mean = z @ w
    pyro.sample("obs", dist.Normal(mean, 1.0).to_event(2), obs=data)
The inference code follows the tutorial:
args = {'K': 2, 'N': X.shape[0], 'P': X.shape[1]}
pyro.clear_param_store()
auto_guide = pyro.infer.autoguide.AutoNormal(fa_model)
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(fa_model, auto_guide, adam, elbo)
losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(torch.tensor(X), args)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))
However, when fitting this with an X of dimension [200, 52] (i.e. N=200, P=52) and K=2, the kernel dies. I suspect a memory issue, since subsampling X down to N=20 (X = X[0:20,:]) runs inference with no problems. It doesn’t seem to be an actual machine-memory limit, as the equivalent matrix multiplication in torch completes without issue:
ww = torch.randn(size=[args['K'],args['P']])
zz = torch.randn(size=[args['N'],args['K']])
(zz @ ww).shape # Returns 200x52 as expected
Any pointers here would be much appreciated. I’m familiar with PPLs like Stan, but this is my first time using Pyro, and I’m wondering if the problem is that I’m not using plate notation or something similar.
Thanks