Can I avoid memory usage growing with the number of observations?

Hello all,

I am still building my understanding of Pyro and probabilistic programming more generally, so please excuse the dumb question.

I notice that the memory required grows linearly with the number of data points when I use a parallel plate, as in model1 below. This makes sense. But the same thing seems to happen when I try to split the data into slices and use a separate plate for each slice; see model2.

Is there any way to avoid the memory requirement growing linearly with the number of observations in MCMC? I suspect the answer is no, but I would appreciate it if you could confirm and explain why that is not possible, or point me to relevant reading.

I understand I could do subsampling in SVI but, unfortunately, SVI is not suitable for my problem.

Many thanks for any help!

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model1(W, Z):
  N = W.shape[0]
  beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones([Z]) * 0.1))
  # One vectorized plate over all N observations.
  with numpyro.plate("W", N):
    numpyro.sample("obs", dist.CategoricalProbs(beta), obs=W)

def model2(W, Z, sliceLength):
  N = W.shape[0]
  beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones([Z]) * 0.1))

  # One separate plate per slice of the data.
  S = N // sliceLength
  for s in range(S):
    Wslice = W[(s * sliceLength):((s + 1) * sliceLength)]
    with numpyro.plate("W_{}".format(s), sliceLength):
      numpyro.sample("obs_{}".format(s), dist.CategoricalProbs(beta), obs=Wslice)

Hi @Elchorro, the memory grows even when you split the data up because Pyro’s inference algorithms first build an autograd graph for the entire model execution, then backprop through that graph after the model executes; thus all of your dist.CategoricalProbs(beta) instances are kept alive at once.
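
You can see this concretely by tracing the model: every sample site, including its distribution object, stays in the trace until execution finishes. A rough sketch, assuming W, Z, and sliceLength are defined as above:

from numpyro.handlers import seed, trace

tr = trace(seed(model2, rng_seed=0)).get_trace(W, Z, sliceLength)
print(list(tr.keys()))    # ['beta', 'obs_0', 'obs_1', ...] -- one site per slice
print(tr["obs_0"]["fn"])  # each site holds its own CategoricalProbs instance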

One way to try to reduce memory in your model might be to move the dist.CategoricalProbs(beta) construction and the broadcast out of the loop, e.g.

def model3(W, Z, sliceLength):
  N = W.shape[0]
  beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones([Z]) * 0.1))

  # Build the expanded distribution once and reuse it in every slice.
  S = N // sliceLength
  cat_dist = dist.CategoricalProbs(beta).expand((sliceLength,))
  for s in range(S):
    Wslice = W[(s * sliceLength):((s + 1) * sliceLength)]
    with numpyro.plate("W_{}".format(s), sliceLength):
      numpyro.sample("obs_{}".format(s), cat_dist, obs=Wslice)

but memory will still grow linearly with data size.

An alternative that would take constant memory would be to convert your Categorical observations into a single Multinomial observation, something like

counts = jnp.bincount(W, minlength=Z)
...
numpyro.sample("counts", dist.Multinomial(total_count=W.shape[0], probs=beta),
               obs=counts)
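
For concreteness, here is a rough sketch of the converted model (the name model4 is just illustrative; your real model may need a different shape of change). The Multinomial likelihood differs from the product of Categoricals only by a constant multinomial coefficient, which does not affect inference over beta:

def model4(W, Z):
  N = W.shape[0]
  beta = numpyro.sample("beta", dist.Dirichlet(jnp.ones([Z]) * 0.1))
  # Collapse N Categorical observations into one Z-length vector of counts.
  # length=Z keeps the output shape static so the model can be jit-compiled.
  counts = jnp.bincount(W, length=Z)
  numpyro.sample("counts", dist.Multinomial(total_count=N, probs=beta),
                 obs=counts)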

Hi @fritzo

Thank you for the suggestions. This is useful. I suspected the memory issue was coming from autograd, but thank you for confirming.

I understand the Multinomial should be constant memory; I will look into whether I can convert my problem (which is a bit more involved than model1 and model2).

The only thing I am not too sure about is why model3 is better than model2 memory-wise. How should I think about it? (I understand why it’s better for other reasons, though.)

Thank you for your time.

“why model3 is better than model2 memory-wise”

model3 is only a little better because a single expanded Categorical distribution is created and shared among all loop iterations, rather than each loop iteration creating a new Categorical. Note those temporary Categorical distributions cannot be garbage collected until after model execution, because Pyro stores them in its trace data structure.
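
A toy illustration of the object counts (dummy numbers, outside of any model):

import jax.numpy as jnp
import numpyro.distributions as dist

beta = jnp.ones(4) / 4  # dummy probabilities, Z = 4
S, sliceLength = 3, 10

# model2 style: a fresh Categorical per iteration; the trace keeps each alive.
per_iter = [dist.CategoricalProbs(beta) for _ in range(S)]

# model3 style: one pre-expanded Categorical shared across all slices.
shared = dist.CategoricalProbs(beta).expand((sliceLength,))
reused = [shared] * S

print(len({id(d) for d in per_iter}))  # 3 distinct objects
print(len({id(d) for d in reused}))    # 1 shared object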
