Thanks for the reply. Here is the error message:
NotImplementedError Traceback (most recent call last)
Input In [7], in <cell line: 7>()
20 svi = SVI(model_1, global_guide, optim, loss=elbo)
21 for i in range(n_iterations):
---> 22 loss = svi.step(data, n_components, n_observations, temp = temperature_array[k])
24 posterior = Predictive(model_1, guide=global_guide, num_samples=2)(data, n_components, n_observations)
26 score[k,j] = torch.sum(torch.abs(torch.mean(posterior['Bernoulli'], axis = 0) - activation))/len(activation)
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/elbo.py:182, in ELBO._get_traces(self, model, guide, args, kwargs)
180 else:
181 for i in range(self.num_particles):
--> 182 yield self._get_trace(model, guide, args, kwargs)
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/enum.py:57, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
55 if detach:
56 raise NotImplementedError("GuideMessenger does not support detach")
---> 57 guide(*args, **kwargs)
58 model_trace, guide_trace = unwrapped_guide.get_traces()
59 else:
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/nn/module.py:637, in pyro_method.<locals>.cached_fn(self, *args, **kwargs)
634 @functools.wraps(fn)
635 def cached_fn(self, *args, **kwargs):
636 with self._pyro_context:
--> 637 return fn(self, *args, **kwargs)
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:46, in AutoMessenger.__call__(self, *args, **kwargs)
44 self._outer_plates = tuple(f.name for f in get_plates())
45 try:
---> 46 return super().__call__(*args, **kwargs)
47 finally:
48 del self._outer_plates
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/guide.py:45, in GuideMessenger.__call__(self, *args, **kwargs)
43 try:
44 with self:
---> 45 self.model(*args, **kwargs)
46 finally:
47 del self.args_kwargs
Input In [6], in model_1(data, n_components, n_observations, temp)
3 ps = pyro.sample('ps', dist.Dirichlet(torch.ones(n_components)/10.))
4 mus = pyro.sample('mus', dist.Gamma(20,2).expand([n_components]).to_event(1))
----> 5 activation = pyro.sample('Bernoulli',
6 dist.RelaxedBernoulliStraightThrough(probs = ps[0],
7 temperature = torch.tensor(temp)).expand([n_observations]).to_event(1))
9 mean = mus[0] + mus[1]*activation
11 pyro.sample("data_target", dist.Normal(loc = mean, scale = torch.tensor(1.)).to_event(1), obs = data)
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/primitives.py:163, in sample(name, fn, *args, **kwargs)
146 msg = {
147 "type": "sample",
148 "name": name,
(...)
160 "continuation": None,
161 }
162 # apply the stack and return its return value
--> 163 apply_stack(msg)
164 return msg["value"]
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/runtime.py:213, in apply_stack(initial_msg)
209 for frame in reversed(stack):
211 pointer = pointer + 1
--> 213 frame._process_message(msg)
215 if msg["stop"]:
216 break
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/messenger.py:154, in Messenger._process_message(self, msg)
152 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
153 if method is not None:
--> 154 return method(msg)
155 return None
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/poutine/guide.py:62, in GuideMessenger._pyro_sample(self, msg)
60 prior = msg["fn"]
61 msg["infer"]["prior"] = prior
---> 62 posterior = self.get_posterior(msg["name"], prior)
63 if isinstance(posterior, torch.Tensor):
64 posterior = dist.Delta(posterior, event_dim=prior.event_dim)
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:277, in AutoHierarchicalNormalMessenger.get_posterior(self, name, prior)
274 transform = biject_to(prior.support)
275 if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
276 # If hierarchical_sites not specified all sites are assumed to be hierarchical
--> 277 loc, scale, weight = self._get_params(name, prior)
278 loc = loc + transform.inv(prior.mean) * weight
279 posterior = dist.TransformedDistribution(
280 dist.Normal(loc, scale).to_event(transform.domain.event_dim),
281 transform.with_cache(),
282 )
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/pyro/infer/autoguide/effect.py:316, in AutoHierarchicalNormalMessenger._get_params(self, name, prior)
314 # if site is hierarchical substract contribution of dependencies from init_loc
315 if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
--> 316 init_prior_mean = transform.inv(prior.mean)
317 init_prior_mean = self._adjust_plates(init_prior_mean, event_dim)
318 init_loc = init_loc - init_weight * init_prior_mean
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/independent.py:78, in Independent.mean(self)
76 @property
77 def mean(self):
---> 78 return self.base_dist.mean
File /nfs/team283/aa16/software/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/distribution.py:122, in Distribution.mean(self)
117 @property
118 def mean(self):
119 """
120 Returns the mean of the distribution.
121 """
--> 122 raise NotImplementedError
NotImplementedError: