I am trying to do a simple inference problem but with two priors (mu and sigma):
import torch
import pyro
def measurements():
mu = pyro.sample('mu', pyro.distributions.Normal(0,50))
sigma = pyro.sample('sigma', pyro.distributions.Uniform(0,25))
x = pyro.sample('x', pyro.distributions.Normal(mu,sigma))
return x
conditioned_measurements = pyro.condition(measurements, data={"x": -27.020})
from torch.distributions import constraints
def measurements_parametrized_guide_constrained():
a = pyro.param("a", torch.tensor(0.))
b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
c = pyro.param("a", torch.tensor(0.), constraint=constraints.positive)
d = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
return pyro.sample("mu", pyro.distributions.Normal(a, b)),pyro.sample("sigma", pyro.distributions.Normal(c, d))
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_measurements,
guide=measurements_parametrized_guide_constrained,
optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
loss=pyro.infer.Trace_ELBO())
losses, a,b,c,d = [], [], [], [], []
num_steps = 2500
for t in range(num_steps):
losses.append(svi.step())
a.append(pyro.param("a").item())
b.append(pyro.param("b").item())
c.append(pyro.param("c").item())
d.append(pyro.param("d").item())
And I got the following error:
ValueError: Error while computing log_prob at site 'x':
The value argument to log_prob must be a Tensor
Trace Shapes:
Param Sites:
Sample Sites:
mu dist |
value |
log_prob |
sigma dist |
value |
log_prob |
x dist |
value |
The part I’m not sure is am I returning the two posteriors from the guide correctly but it seems the error is about something else…
Update: I have corrected the some bugs identified by the replies and myself as the following:
import torch
import pyro
def measurements():
mu = pyro.sample('mu', pyro.distributions.Normal(0,50))
sigma = pyro.sample('sigma', pyro.distributions.Uniform(0.001,25))
x = pyro.sample('x', pyro.distributions.Normal(mu,sigma))
return x
conditioned_measurements = pyro.condition(measurements, data={"x": torch.tensor(-27.020)})
from torch.distributions import constraints
def measurements_parametrized_guide_constrained():
a = pyro.param("a", torch.tensor(0.))
b = pyro.param("b", torch.tensor(50.), constraint=constraints.positive)
c = pyro.param("c", torch.tensor(0.001), constraint=constraints.positive)
d = pyro.param("d", torch.tensor(25.), constraint=constraints.positive)
return pyro.sample("mu", pyro.distributions.Normal(a, b)),pyro.sample("sigma", pyro.distributions.Uniform(c, d))
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_measurements,
guide=measurements_parametrized_guide_constrained,
optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
loss=pyro.infer.Trace_ELBO())
losses, a,b,c,d = [], [], [], [], []
num_steps = 2500
for t in range(num_steps):
losses.append(svi.step())
a.append(pyro.param("a").item())
b.append(pyro.param("b").item())
c.append(pyro.param("c").item())
d.append(pyro.param("d").item())
unfortunately there is still error but different from before:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/uniform.py in log_prob(self, value)
72 if self._validate_args:
---> 73 self._validate_sample(value)
74 lb = self.low.le(value).type_as(self.low)
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
276 if not support.check(value).all():
--> 277 raise ValueError('The value argument must be within the support')
278
ValueError: The value argument must be within the support
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<ipython-input-122-78d4c23ea3f5> in <module>
33
34 for t in range(num_steps):
---> 35 losses.append(svi.step())
36 a.append(pyro.param("a").item())
37 b.append(pyro.param("b").item())
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
146
147 params = set(
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
180 else:
181 for i in range(self.num_particles):
--> 182 yield self._get_trace(model, guide, args, kwargs)
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
73 model_trace = prune_subsample_sites(model_trace)
74
---> 75 model_trace.compute_log_prob()
76 guide_trace.compute_score_parts()
77 if is_validation_enabled():
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
234 _, exc_value, traceback = sys.exc_info()
235 shapes = self.format_shapes(last_site=site["name"])
--> 236 raise ValueError(
237 "Error while computing log_prob at site '{}':\n{}\n{}".format(
238 name, exc_value, shapes
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
228 if "log_prob" not in site:
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
232 )
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/uniform.py in log_prob(self, value)
71 def log_prob(self, value):
72 if self._validate_args:
---> 73 self._validate_sample(value)
74 lb = self.low.le(value).type_as(self.low)
75 ub = self.high.gt(value).type_as(self.low)
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
275 assert support is not None
276 if not support.check(value).all():
--> 277 raise ValueError('The value argument must be within the support')
278
279 def _get_checked_instance(self, cls, _instance=None):
ValueError: Error while computing log_prob at site 'sigma':
The value argument must be within the support
Trace Shapes:
Param Sites:
Sample Sites:
mu dist |
value |
log_prob |
sigma dist |
value |