Variational NMF

Hi all,

I am trying to come up with a variational version of non-negative matrix factorization (NMF), which I am implementing as follows.

X_ij | (SE)_ij ~ Poisson((SE)_ij)
S_ij ~ Normal(0, 1)
E_ij ~ Normal(0, 1)

where dim(X) = (p x n), dim(S) = (p x r), dim(E) = (r x n), and r << n.

I started from some code I found in the forum; I use the softmax and exponential functions to enforce the non-negativity constraint, and the rest should be straightforward.
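For concreteness, a count matrix of the right shape can be simulated from this generative process like so (a minimal, self-contained sketch; the X used further below stands in for my real data):

import torch

p, r, n = 96, 4, 2778

# unconstrained factors drawn from the standard-normal priors
S = torch.randn(p, r)
E = torch.randn(r, n)

# non-negative rate: softmax normalizes the columns of S, exp makes E positive
rate = torch.softmax(S, dim=0) @ torch.exp(E)

# Poisson counts of shape (p, n)
X = torch.poisson(rate)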

import torch
import pyro
from torch.distributions import constraints
from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO

p = 96
r = 4
n = 2778

# softmax over dim=0 normalizes each column to sum to one
softmax = torch.nn.Softmax(dim=0)

def model(data):
    # standard-normal priors on the unconstrained factors
    s_mean0 = torch.zeros([p, r])
    s_std0 = torch.ones([p, r])

    e_mean0 = torch.zeros([r, n])
    e_std0 = torch.ones([r, n])

    s = pyro.sample("s", pyro.distributions.Normal(loc=s_mean0, scale=s_std0))
    e = pyro.sample("e", pyro.distributions.Normal(loc=e_mean0, scale=e_std0))
    # non-negative Poisson rate via the softmax/exp transforms
    expectation = torch.matmul(softmax(s), torch.exp(e))

    pyro.sample("obs", pyro.distributions.Poisson(expectation), obs=data)

def guide(data):
    # variational parameters of the mean-field normal posteriors
    qs_mean = pyro.param("qs_mean", torch.zeros([p, r]))
    qs_stddv = pyro.param("qs_stddv", torch.ones([p, r]), constraint=constraints.positive)

    qe_mean = pyro.param("qe_mean", torch.zeros([r, n]))
    qe_stddv = pyro.param("qe_stddv", torch.ones([r, n]), constraint=constraints.positive)

    # sample statements must use the same site names as in the model
    s = pyro.sample("s", pyro.distributions.Normal(loc=qs_mean, scale=qs_stddv))
    e = pyro.sample("e", pyro.distributions.Normal(loc=qe_mean, scale=qe_stddv))

adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 2000

# do gradient steps
for step in range(n_steps):
    print(step)
    svi.step(X)
    if step % 100 == 0:
        print('.', end='')

However, when I run the code I receive the following error and can't work out why. Can anyone help?

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-53-830112326fba> in <module>()
     41 for step in range(n_steps):
     42     print(step)
---> 43     svi.step(X)
     44     if step % 100 == 0:
     45         print('.', end='')

/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    123         # get loss and compute gradients
    124         with poutine.trace(param_only=True) as param_capture:
--> 125             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    126 
    127         params = set(site["value"].unconstrained()

/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    121         loss = 0.0
    122         # grab a trace from the generator
--> 123         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    124             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
    125             loss += loss_particle / self.num_particles

/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
    165         else:
    166             for i in range(self.num_particles):
--> 167                 yield self._get_trace(model, guide, args, kwargs)

/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
     48         """
     49         model_trace, guide_trace = get_importance_trace(
---> 50             "flat", self.max_plate_nesting, model, guide, args, kwargs)
     51         if is_validation_enabled():
     52             check_if_enumerated(guide_trace)

/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     55         for site in model_trace.nodes.values():
     56             if site["type"] == "sample":
---> 57                 check_site_shape(site, max_plate_nesting)
     58         for site in guide_trace.nodes.values():
     59             if site["type"] == "sample":

/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
    283                 '- enclose the batched tensor in a with plate(...): context',
    284                 '- .to_event(...) the distribution being sampled',
--> 285                 '- .permute() data dimensions']))
    286 
    287     # Check parallel dimensions on the left of max_plate_nesting.

ValueError: at site "s", invalid log_prob shape
  Expected [], actual [96, 4]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

@sagar I think you are missing the declarations of the dependent/independent dimensions of the latent variables s and e. A simple solution is to add .to_event(...) at the end of Normal(...).
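To see what this does to the shapes (a quick check: .to_event(n) reinterprets the n rightmost batch dimensions as event dimensions, so log_prob sums over them):

import torch
import pyro.distributions as dist

d = dist.Normal(torch.zeros(96, 4), torch.ones(96, 4))
print(d.batch_shape, d.event_shape)           # torch.Size([96, 4]) torch.Size([])
print(d.log_prob(torch.zeros(96, 4)).shape)   # torch.Size([96, 4]) -> the shape error above

d2 = d.to_event(2)
print(d2.batch_shape, d2.event_shape)         # torch.Size([]) torch.Size([96, 4])
print(d2.log_prob(torch.zeros(96, 4)).shape)  # torch.Size([]) -> scalar, as SVI expects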

Hi fehiepsi, thank you very much for your help. I changed the code as follows, which I believe is right, but to make sure we're on the same page: I want (p x r) draws from (p x r) independent normals for s, and likewise (r x n) draws from (r x n) independent normals for e. I then multiply the transformed factors, softmax(s) and exp(e), to obtain the (p x n) expectation of my Poisson likelihood, which I fit to my (p x n) data matrix. In the end I'd like to have variational distributions over s and e.
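As a quick sanity check on those shapes (plain torch, independent of the Pyro question):

import torch

s = torch.randn(96, 4)
e = torch.randn(4, 2778)
expectation = torch.softmax(s, dim=0) @ torch.exp(e)
print(expectation.shape)  # torch.Size([96, 2778]) -- matches the data matrix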

pyro.clear_param_store()
p = 96
r = 4
n = 2778

softmax = torch.nn.Softmax(dim=0)

def model(data):
    s_mean0 = torch.zeros([p, r])
    s_std0 = torch.ones([p, r])

    e_mean0 = torch.zeros([r, n])
    e_std0 = torch.ones([r, n])

    s = pyro.sample("s", pyro.distributions.Normal(loc=s_mean0, scale=s_std0).to_event(0))
    e = pyro.sample("e", pyro.distributions.Normal(loc=e_mean0, scale=e_std0).to_event(0))
    expectation = torch.matmul(softmax(s), torch.exp(e))
    
    pyro.sample("obs", pyro.distributions.Poisson(expectation).to_event(0), obs=data)

def guide(data):
    qs_mean = pyro.param("qs_mean", torch.zeros([p, r]))
    qs_stddv = pyro.param("qs_stddv", torch.ones([p, r]), constraint=constraints.positive)
    
    qe_mean = pyro.param("qe_mean", torch.zeros([r, n]))
    qe_stddv = pyro.param("qe_stddv", torch.ones([r, n]), constraint=constraints.positive)
    
    s = pyro.sample("s", pyro.distributions.Normal(loc=qs_mean, scale=qs_stddv).to_event(0))
    e = pyro.sample("e", pyro.distributions.Normal(loc=qe_mean, scale=qe_stddv).to_event(0))

adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)

svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

n_steps = 2000

# do gradient steps
for step in range(n_steps):
    print(step)
    svi.step(X)
    if step % 100 == 0:
        print('.', end='')

Running the modified code, however, still gives me the same shape error:

(The traceback is identical to the one above, ending in:)
ValueError: at site "s", invalid log_prob shape
  Expected [], actual [96, 4]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Any ideas?

Hi @sagar, I think you want to use one of these (they are equivalent in Pyro):

  • to_event(2)
  • to_event()
  • use two pyro.plate statements

More details are explained in the tensor shapes tutorial linked in my previous comment. Note that to_event(0) has no effect, which is why you still see the same error. All of the above give you independent samples, so no need to worry about that.
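For example, applied to the model above, the to_event(2) variant would look like this (a sketch; the sample sites in the guide need the same .to_event(2)):

def model(data):
    s_mean0 = torch.zeros([p, r])
    s_std0 = torch.ones([p, r])

    e_mean0 = torch.zeros([r, n])
    e_std0 = torch.ones([r, n])

    # to_event(2) declares both tensor dimensions as event dimensions, so each
    # site's log_prob reduces to a scalar and the shape check passes
    s = pyro.sample("s", pyro.distributions.Normal(loc=s_mean0, scale=s_std0).to_event(2))
    e = pyro.sample("e", pyro.distributions.Normal(loc=e_mean0, scale=e_std0).to_event(2))
    expectation = torch.matmul(softmax(s), torch.exp(e))

    # the observed site is also batched over (p, n), so it needs to_event(2) too
    pyro.sample("obs", pyro.distributions.Poisson(expectation).to_event(2), obs=data)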