Hi,
I’m new to pyro and trying to understand the basics of Bayesian regression from a Bayesian Linear Regression example. (Bayesian Regression - Inference Algorithms (Part 2) — Pyro Tutorials 1.8.4 documentation)
But when I run the svi inference part, sometimes I get ValueError saying that the scale parameter of the normal distribution is invalid values. (inference code and full error message attached below)
It seems that the error occurs when a negative value is sampled for the scale parameter, which is strange because positive constraints are applied to all scale parameters.
So, my questions are
- Is there a way to avoid such a ValueError besides running the code over and over again until it works?
- How constraints on the sample work? I thought they work as “hard” constraints, so samples not satisfying the constraints are filtered before they are evaluated for the inference. But it seems like negative scale value (which violates constraints) is somehow sampled and used. So I’m misunderstanding something, but I’m not sure what it is.
def model(is_cont_africa, ruggedness, log_gdp=None):
a = pyro.sample("a", dist.Normal(0., 10.))
b_a = pyro.sample("bA", dist.Normal(0., 1.))
b_r = pyro.sample("bR", dist.Normal(0., 1.))
b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
sigma = pyro.sample("sigma", dist.Normal(0., 10.))
mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
with pyro.plate("data", len(ruggedness)):
return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
def guide(is_cont_africa, ruggedness, log_gdp):
a_loc = pyro.param('a_loc', lambda: torch.tensor(0.))
a_scale = pyro.param('a_scale', lambda:torch.tensor(1.), constraint=constraints.positive)
b_loc = pyro.param('b_loc', lambda: torch.randn(3))
b_scale = pyro.param('b_scale', lambda: torch.ones(3), constraint=constraints.positive)
sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.), constraint=constraints.positive)
a = pyro.sample("a", dist.Normal(a_loc, a_scale))
b_a = pyro.sample("bA", dist.Normal(b_loc[0], b_scale[0]))
b_r = pyro.sample("bR", dist.Normal(b_loc[1], b_scale[1]))
b_ar = pyro.sample("bAR", dist.Normal(b_loc[2], b_scale[2]))
sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)
svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
for k, v in predictive(log_gdp, is_cont_africa, ruggedness).items()
if k != "obs"}
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
26 with self.clone():
---> 27 return func(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
Input In [251], in model(is_cont_africa, ruggedness, log_gdp)
10 with pyro.plate("data", len(ruggedness)):
---> 11 return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/distributions/distribution.py:18, in DistributionMeta.__call__(cls, *args, **kwargs)
17 return result
---> 18 return super().__call__(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
54 if not valid.all():
---> 55 raise ValueError(
56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
59 f"to satisfy the constraint {repr(constraint)}, "
60 f"but found invalid values:\n{value}"
61 )
62 super(Distribution, self).__init__()
ValueError: Expected parameter scale (Tensor of shape (170,)) of distribution Normal(loc: torch.Size([170]), scale: torch.Size([170])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529])
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
Input In [274], in <cell line: 5>()
3 num_samples = 1000
4 predictive = Predictive(model, guide=guide, num_samples=num_samples)
5 svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
----> 6 for k, v in predictive(log_gdp, is_cont_africa, ruggedness).items()
7 if k != "obs"}
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110 return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
263 return_sites = None if not return_sites else return_sites
264 posterior_samples = _predictive(
265 self.guide,
266 posterior_samples,
(...)
271 model_kwargs=kwargs,
272 )
--> 273 return _predictive(
274 self.model,
275 posterior_samples,
276 self.num_samples,
277 return_sites=return_sites,
278 parallel=self.parallel,
279 model_args=args,
280 model_kwargs=kwargs,
281 )
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:78, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
67 def _predictive(
68 model,
69 posterior_samples,
(...)
75 model_kwargs={},
76 ):
77 model = torch.no_grad()(poutine.mask(model, mask=False))
---> 78 max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
79 vectorize = pyro.plate(
80 "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
81 )
82 model_trace = prune_subsample_sites(
83 poutine.trace(model).get_trace(*model_args, **model_kwargs)
84 )
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:21, in _guess_max_plate_nesting(model, args, kwargs)
15 """
16 Guesses max_plate_nesting by running the model once
17 without enumeration. This optimistically assumes static model
18 structure.
19 """
20 with poutine.block():
---> 21 model_trace = poutine.trace(model).get_trace(*args, **kwargs)
22 sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
24 dims = [
25 frame.dim
26 for site in sites
27 for frame in site["cond_indep_stack"]
28 if frame.vectorized
29 ]
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
190 def get_trace(self, *args, **kwargs):
191 """
192 :returns: data structure
193 :rtype: pyro.poutine.Trace
(...)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
178 exc = exc_type("{}\n{}".format(exc_value, shapes))
179 exc = exc.with_traceback(traceback)
--> 180 raise exc from e
181 self.msngr.trace.add_node(
182 "_RETURN", name="_RETURN", type="return", value=ret
183 )
184 return ret
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
170 self.msngr.trace.add_node(
171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
Input In [251], in model(is_cont_africa, ruggedness, log_gdp)
8 mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
10 with pyro.plate("data", len(ruggedness)):
---> 11 return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/distributions/distribution.py:18, in DistributionMeta.__call__(cls, *args, **kwargs)
16 if result is not None:
17 return result
---> 18 return super().__call__(*args, **kwargs)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
48 else:
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
53 valid = constraint.check(value)
54 if not valid.all():
---> 55 raise ValueError(
56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
59 f"to satisfy the constraint {repr(constraint)}, "
60 f"but found invalid values:\n{value}"
61 )
62 super(Distribution, self).__init__()
ValueError: Expected parameter scale (Tensor of shape (170,)) of distribution Normal(loc: torch.Size([170]), scale: torch.Size([170])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529])
Trace Shapes:
Param Sites:
Sample Sites:
a dist |
value |
bA dist |
value |
bR dist |
value |
bAR dist |
value |
sigma dist |
value |
data dist |
value 170 |