Kernel dies when fitting factor analysis model - possible memory issue?

Hi all,

I’m attempting to fit a simple factor analysis model using Pyro:

import torch
import pyro
import pyro.distributions as dist

def fa_model(data=None, args=None, batch_size=None):
    K, P, N = args['K'], args['P'], args['N']

    w = pyro.sample("w", dist.Normal(torch.randn(size=[K,P]), 1.0).to_event(2))
    z = pyro.sample("z", dist.Normal(torch.randn(size=[N,K]), 1.0).to_event(2))

    mean = z @ w  # shape [N, P]

    pyro.sample("obs", dist.Normal(mean, 1.0).to_event(2), obs=data)

The inference code follows the tutorial:

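import logging

logging.basicConfig(format="%(message)s", level=logging.INFO)  # so logging.info output is shown

# Convert the data once, outside the training loop, as float32 to match the model's tensors.
data = torch.tensor(X, dtype=torch.float32)
args = {'K': 2, 'N': data.shape[0], 'P': data.shape[1]}

pyro.clear_param_store()
auto_guide = pyro.infer.autoguide.AutoNormal(fa_model)
adam = pyro.optim.Adam({"lr": 0.02})  # Consider decreasing learning rate.
elbo = pyro.infer.Trace_ELBO()
svi = pyro.infer.SVI(fa_model, auto_guide, adam, elbo)

losses = []
for step in range(1000):  # Consider running for more steps.
    loss = svi.step(data, args)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))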

However, when fitting this to an X of shape [200, 52] (i.e. N=200, P=52) with K=2, the kernel dies. I suspect a memory issue, since subsampling X down to N=20 (X = X[0:20,:]) runs inference without problems. It doesn’t seem to be a lack of actual machine memory, though, because the same matrix multiplication completes fine in plain torch:

ww = torch.randn(size=[args['K'],args['P']])
zz = torch.randn(size=[args['N'],args['K']])
(zz @ ww).shape # Returns 200x52 as expected
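As a quick sanity check on the memory hypothesis, the tensors involved are tiny. A sketch, assuming the default float32 dtype (4 bytes per element); `nbytes` is a hypothetical helper written for this check:

```python
import torch

N, P, K = 200, 52, 2  # shapes from the post

ww = torch.randn(K, P)
zz = torch.randn(N, K)
mean = zz @ ww

def nbytes(t: torch.Tensor) -> int:
    """Bytes occupied by a tensor's elements (element size times element count)."""
    return t.element_size() * t.nelement()

for name, t in [("w", ww), ("z", zz), ("mean", mean)]:
    print(name, tuple(t.shape), nbytes(t), "bytes")
# prints:
# w (2, 52) 416 bytes
# z (200, 2) 1600 bytes
# mean (200, 52) 41600 bytes
```

At roughly 40 KB for the largest tensor, raw tensor size is unlikely to explain the crash on its own.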

Any pointers would be much appreciated. I’m familiar with PPLs like Stan, but this is my first time using Pyro, and I’m wondering whether the problem is that I’m not using plate notation or something similar.

Thanks

i don’t know what your issue is, but it’s unlikely that you’ve specified the model correctly. as written, you draw a new random location for each normal prior every time the model is invoked. you probably want something like

w = pyro.sample("w", dist.Normal(torch.zeros(K, P), 1.0).to_event(2))

also, “kernel dies” isn’t much information to go on.

Thanks, I’ve updated the model accordingly, but the error persists.

This is the entirety of the error:


Is there a relevant log from pyro to check?

i suggest running your code in a script. that way you are much more likely to see informative stack traces instead of silly pop-up boxes; i have no idea how exceptions are handled in whatever environment you’re using (jupyter notebook? something else?).

Hi @martinjankowiak,
Thanks for the pointer. The eventual error was:

There appear to be 1 leaked semaphore objects to clean up at shutdown

which this GitHub issue attributes to conda. I’ll try some of the suggested fixes and report back in case other Pyro users run into the same problem.