"Not Implemented Error" for "Relaxed Bernoulli Straight Through" with "Auto Hierarchical Normal Messenger"

Hi,

I get a “Not Implemented Error” during training when I try to use a “Relaxed Bernoulli Straight Through” or “Relaxed Bernoulli” in a model with an “Auto Hierarchical Normal Messenger” autoguide. Is this expected? Are there plans to make these distributions compatible with this autoguide, or are there fundamental limitations that prevent this?

Thanks for the info!

Alexander

PS: I had to add extra spaces, because otherwise I could not make a post due to an error saying “one or more of your words are too long”…

Hi @AlexanderA, could you paste your error message? I’m not sure why these would fail.

Thanks for the reply, here is the error message:


NotImplementedError Traceback (most recent call last)
Input In [7], in <cell line: 7>()
20 svi = SVI(model_1, global_guide, optim, loss=elbo)
21 for i in range(n_iterations):
---> 22 loss = svi.step(data, n_components, n_observations, temp = temperature_array[k])
24 posterior = Predictive(model_1, guide=global_guide, num_samples=2)(data, n_components, n_observations)
26 score[k,j] = torch.sum(torch.abs(torch.mean(posterior['Bernoulli'], axis = 0) - activation))/len(activation)

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
180 else:
181 for i in range(self.num_particles):
--> 182 yield self._get_trace(model, guide, args, kwargs)

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/enum.py:57, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
55 if detach:
56 raise NotImplementedError("GuideMessenger does not support detach")
---> 57 guide(*args, **kwargs)
58 model_trace, guide_trace = unwrapped_guide.get_traces()
59 else:

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/nn/module.py:637, in pyro_method.<locals>.cached_fn(self, *args, **kwargs)
634 @functools.wraps(fn)
635 def cached_fn(self, *args, **kwargs):
636 with self._pyro_context:
--> 637 return fn(self, *args, **kwargs)

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:46, in AutoMessenger.__call__(self, *args, **kwargs)
44 self._outer_plates = tuple(f.name for f in get_plates())
45 try:
---> 46 return super().__call__(*args, **kwargs)
47 finally:
48 del self._outer_plates

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/guide.py:45, in GuideMessenger.__call__(self, *args, **kwargs)
43 try:
44 with self:
---> 45 self.model(*args, **kwargs)
46 finally:
47 del self.args_kwargs

Input In [6], in model_1(data, n_components, n_observations, temp)
3 ps = pyro.sample('ps', dist.Dirichlet(torch.ones(n_components)/10.))
4 mus = pyro.sample('mus', dist.Gamma(20,2).expand([n_components]).to_event(1))
----> 5 activation = pyro.sample('Bernoulli',
6 dist.RelaxedBernoulliStraightThrough(probs = ps[0],
7 temperature = torch.tensor(temp)).expand([n_observations]).to_event(1))
9 mean = mus[0] + mus[1]*activation
11 pyro.sample("data_target", dist.Normal(loc = mean, scale = torch.tensor(1.)).to_event(1), obs = data)

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs)
146 msg = {
147 "type": "sample",
148 "name": name,
(…)
160 "continuation": None,
161 }
162 # apply the stack and return its return value
--> 163 apply_stack(msg)
164 return msg["value"]

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/runtime.py:213, in apply_stack(initial_msg)
209 for frame in reversed(stack):
211 pointer = pointer + 1
--> 213 frame._process_message(msg)
215 if msg["stop"]:
216 break

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/messenger.py:154, in Messenger._process_message(self, msg)
152 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
153 if method is not None:
--> 154 return method(msg)
155 return None

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/guide.py:62, in GuideMessenger._pyro_sample(self, msg)
60 prior = msg["fn"]
61 msg["infer"]["prior"] = prior
---> 62 posterior = self.get_posterior(msg["name"], prior)
63 if isinstance(posterior, torch.Tensor):
64 posterior = dist.Delta(posterior, event_dim=prior.event_dim)

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:277, in AutoHierarchicalNormalMessenger.get_posterior(self, name, prior)
274 transform = biject_to(prior.support)
275 if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
276 # If hierarchical_sites not specified all sites are assumed to be hierarchical
--> 277 loc, scale, weight = self._get_params(name, prior)
278 loc = loc + transform.inv(prior.mean) * weight
279 posterior = dist.TransformedDistribution(
280 dist.Normal(loc, scale).to_event(transform.domain.event_dim),
281 transform.with_cache(),
282 )

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:316, in AutoHierarchicalNormalMessenger._get_params(self, name, prior)
314 # if site is hierarchical substract contribution of dependencies from init_loc
315 if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
--> 316 init_prior_mean = transform.inv(prior.mean)
317 init_prior_mean = self._adjust_plates(init_prior_mean, event_dim)
318 init_loc = init_loc - init_weight * init_prior_mean

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/independent.py:78, in Independent.mean(self)
76 @property
77 def mean(self):
---> 78 return self.base_dist.mean

File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/distribution.py:122, in Distribution.mean(self)
117 @property
118 def mean(self):
119 """
120 Returns the mean of the distribution.
121 """
--> 122 raise NotImplementedError

NotImplementedError:

I have also made a minimal example here.
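The last frames of the traceback show where `.mean` is needed: `AutoHierarchicalNormalMessenger` computes `transform.inv(prior.mean)` with `transform = biject_to(prior.support)` (`effect.py:278` and `effect.py:316`), and the `Independent` wrapper simply forwards `.mean` to the relaxed Bernoulli, which has none. The sketch below isolates that step with plain `torch.distributions`, assuming a scalar unit-interval support (the model's site actually has an `Independent` support) and made-up mean values. On the unit interval the inverse transform is the logit, so any mean strictly between 0 and 1 should give the guide a finite location.

```python
import torch
from torch.distributions import biject_to, constraints

# RelaxedBernoulli and its straight-through variant are supported on the unit interval.
transform = biject_to(constraints.unit_interval)

# Hypothetical prior means; in the traceback this is prior.mean, which raised.
prior_mean = torch.tensor([0.1, 0.3, 0.5])

# Same operation as effect.py:278 and :316: move the mean to unconstrained space.
unconstrained = transform.inv(prior_mean)

# Applying the forward transform maps the values back onto the unit interval.
roundtrip = transform(unconstrained)
print(unconstrained)
print(roundtrip)
```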

It looks like your RelaxedBernoulliStraightThrough distribution doesn’t implement the .mean property. You could try patching it in your codebase (until we implement it upstream):

from pyro.distributions import RelaxedBernoulliStraightThrough
RelaxedBernoulliStraightThrough.mean = property(lambda self: self.probs)

Does that work for you?
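Why `self.probs` is a reasonable value for `.mean`: assuming the straight-through forward pass rounds the relaxed sample at 0.5 (my reading of Pyro's implementation, not something stated in this thread), the hard sample is 1 exactly when `logit(p) + L > 0` for logistic noise `L`, which happens with probability `p` at every temperature. The plain `RelaxedBernoulli` returns the soft sample instead, whose mean generally differs from `p`, so for that distribution the same patch is only an approximation. A quick Monte Carlo check in pure Python, with an arbitrary `p = 0.3`:

```python
import math
import random


def logit(p):
    return math.log(p) - math.log1p(-p)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def relaxed_bernoulli_pair(p, temperature, rng):
    """Return (soft, hard) for one binary Concrete (relaxed Bernoulli) draw."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep the logs finite
    noise = logit(u)  # logistic noise via the inverse CDF
    soft = sigmoid((logit(p) + noise) / temperature)
    hard = 1.0 if soft > 0.5 else 0.0  # assumed straight-through rounding
    return soft, hard


def empirical_means(p, temperature, n=100_000, seed=0):
    rng = random.Random(seed)
    soft_total = hard_total = 0.0
    for _ in range(n):
        soft, hard = relaxed_bernoulli_pair(p, temperature, rng)
        soft_total += soft
        hard_total += hard
    return soft_total / n, hard_total / n


results = {t: empirical_means(0.3, t) for t in (0.1, 0.5, 1.0, 2.0)}
for t, (soft_mean, hard_mean) in results.items():
    print(f"T={t}: hard mean {hard_mean:.3f}, soft mean {soft_mean:.3f}")
```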

You could also switch to AutoNormalMessenger, which I believe doesn’t require .mean.

Hi,

In my actual project, I would like to use a RelaxedCategoricalStraightThrough (so not a Bernoulli). What would the patch be for that distribution?

Thanks!

Alexander

The same patch should work.

Thanks! That worked indeed.