I am trying to come up with a variational version of non negative matrix factorization (NMF) which I am trying to implement as follows.
Xij | (SE)ij ~ Poisson((SE)ij)
Sij ~ Normal(0,1)
Eij ~ Normal(0,1)
where dim(X) = (p x n) and dim(S) = (p x r) and dim(E) = (r x n) and r << n.
I started with some code that I found in the forum; I use the softmax and exponential function to ensure the non-negativity constraint, the rest should be straight forward.
p = 96
r = 4
n = 2778
softmax = torch.nn.Softmax(dim=0)
def model(data):
s_mean0 = torch.zeros([p, r])
s_std0 = torch.ones([p, r])
e_mean0 = torch.zeros([r, n])
e_std0 = torch.ones([r, n])
s = pyro.sample("s", pyro.distributions.Normal(loc = s_mean0, scale = s_std0))
e = pyro.sample("e", pyro.distributions.Normal(loc = e_mean0, scale = e_std0))
expectation = torch.matmul(softmax(s), torch.exp(e))
pyro.sample("obs", pyro.distributions.Poisson(expectation), obs=data)
def guide(data):
qs_mean = pyro.param("qs_mean", torch.zeros([p, r]))
qs_stddv = pyro.param("qw_stddv", torch.ones([p, r]), constraint=constraints.positive)
qe_mean = pyro.param("qe_mean", torch.zeros([r, n]))
qe_stddv = pyro.param("qe_stddv", torch.ones([r, n]), constraint=constraints.positive)
s = pyro.sample("s", pyro.distributions.Normal(loc = qs_mean, scale = qs_stddv))
e = pyro.sample("e", pyro.distributions.Normal(loc = qe_mean, scale = qe_stddv))
expectation = torch.matmul(softmax(s), torch.exp(e))
adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
n_steps = 2000
# do gradient steps
for step in range(n_steps):
if step % 100 == 0:
print('.', end='')
However, when I try to run the code I receive the following error, and coudn’t work out why. Can anyone help out?
ValueError Traceback (most recent call last)
<ipython-input-53-830112326fba> in <module>()
41 for step in range(n_steps):
42 print(step)
---> 43 svi.step(X)
44 if step % 100 == 0:
45 print('.', end='')
/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
123 # get loss and compute gradients
124 with poutine.trace(param_only=True) as param_capture:
--> 125 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
127 params = set(site["value"].unconstrained()
/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
121 loss = 0.0
122 # grab a trace from the generator
--> 123 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
124 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
125 loss += loss_particle / self.num_particles
/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
165 else:
166 for i in range(self.num_particles):
--> 167 yield self._get_trace(model, guide, args, kwargs)
/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
48 """
49 model_trace, guide_trace = get_importance_trace(
---> 50 "flat", self.max_plate_nesting, model, guide, args, kwargs)
51 if is_validation_enabled():
52 check_if_enumerated(guide_trace)
/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
55 for site in model_trace.nodes.values():
56 if site["type"] == "sample":
---> 57 check_site_shape(site, max_plate_nesting)
58 for site in guide_trace.nodes.values():
59 if site["type"] == "sample":
/hps/nobackup/research/sds-pawg/gerstung/harald/tfgpu-1.10/lib/python3.6/site-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
283 '- enclose the batched tensor in a with plate(...): context',
284 '- .to_event(...) the distribution being sampled',
--> 285 '- .permute() data dimensions']))
287 # Check parallel dimensions on the left of max_plate_nesting.
ValueError: at site "s", invalid log_prob shape
Expected [], actual [96, 4]
Try one of the following fixes:
- enclose the batched tensor in a with plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions