Invalid value occurring during SVI

Hi,
I’m new to Pyro and am trying to understand the basics of Bayesian regression from the Bayesian linear regression example (Bayesian Regression - Inference Algorithms (Part 2) — Pyro Tutorials 1.8.1 documentation).
But when I run the SVI inference part, I sometimes get a ValueError saying that the scale parameter of the Normal distribution has invalid values (inference code and full error message attached below).
It seems that the error occurs when a negative value is sampled for the scale parameter, which is strange because positive constraints are applied to all scale parameters.

So, my questions are:

• Is there a way to avoid such a ValueError besides running the code over and over again until it works?
• How do constraints on samples work? I thought they worked as “hard” constraints, so that samples not satisfying the constraints are filtered out before they are used for inference. But it seems that a negative scale value (which violates the constraint) is somehow sampled and used. So I’m misunderstanding something, but I’m not sure what it is.
``````def model(is_cont_africa, ruggedness, log_gdp=None):
a = pyro.sample("a", dist.Normal(0., 10.))
b_a = pyro.sample("bA", dist.Normal(0., 1.))
b_r = pyro.sample("bR", dist.Normal(0., 1.))
b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
sigma = pyro.sample("sigma", dist.Normal(0., 10.))

mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness

with pyro.plate("data", len(ruggedness)):
return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

def guide(is_cont_africa, ruggedness, log_gdp=None):
    """Mean-field variational guide for ``model``.

    Each latent site ("a", "bA", "bR", "bAR", "sigma") gets an independent
    variational distribution with learnable parameters:

    * "a" ~ Normal(a_loc, a_scale)
    * "bA", "bR", "bAR" ~ Normal(b_loc[i], b_scale[i]) for i = 0, 1, 2
    * "sigma" ~ LogNormal(log(sigma_loc), 0.05), so every draw is strictly
      positive (sigma is used as a Normal scale in the model) and
      ``sigma_loc`` is the median of the sigma posterior.

    The data arguments are not used here; they are accepted only so the
    guide's signature matches ``model``. ``log_gdp`` defaults to ``None``
    so the guide can be called without observations, as the model can.
    """
    a_loc = pyro.param('a_loc', lambda: torch.tensor(0.))
    a_scale = pyro.param('a_scale', lambda: torch.tensor(1.), constraint=constraints.positive)
    b_loc = pyro.param('b_loc', lambda: torch.randn(3))
    b_scale = pyro.param('b_scale', lambda: torch.ones(3), constraint=constraints.positive)
    # `constraints.positive` keeps this *parameter* positive. It does not
    # restrict values drawn at the "sigma" sample site, so that site needs a
    # distribution whose support is positive.
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.), constraint=constraints.positive)

    pyro.sample("a", dist.Normal(a_loc, a_scale))
    pyro.sample("bA", dist.Normal(b_loc[0], b_scale[0]))
    pyro.sample("bR", dist.Normal(b_loc[1], b_scale[1]))
    pyro.sample("bAR", dist.Normal(b_loc[2], b_scale[2]))
    # A Normal(sigma_loc, 0.05) here could return negative sigma; a LogNormal
    # centred (in log space) on log(sigma_loc) cannot.
    pyro.sample("sigma", dist.LogNormal(sigma_loc.log(), torch.tensor(0.05)))

num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)
# Data must be passed in the order of model's signature:
# (is_cont_africa, ruggedness, log_gdp). Every latent site kept here
# ("a", "bA", "bR", "bAR", "sigma") is scalar, so its samples flatten to
# shape (num_samples,).
svi_samples = {
    site: values.reshape(num_samples).detach().cpu().numpy()
    for site, values in predictive(is_cont_africa, ruggedness, log_gdp).items()
    if site != "obs"
}

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
173 try:
--> 174     ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:

26 with self.clone():
---> 27     return func(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12     return fn(*args, **kwargs)

Input In [251], in model(is_cont_africa, ruggedness, log_gdp)
10 with pyro.plate("data", len(ruggedness)):
---> 11     return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/distributions/distribution.py:18, in DistributionMeta.__call__(cls, *args, **kwargs)
17         return result
---> 18 return super().__call__(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
49     batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
54         if not valid.all():
---> 55             raise ValueError(
56                 f"Expected parameter {param} "
57                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58                 f"of distribution {repr(self)} "
59                 f"to satisfy the constraint {repr(constraint)}, "
60                 f"but found invalid values:\n{value}"
61             )
62 super(Distribution, self).__init__()

ValueError: Expected parameter scale (Tensor of shape (170,)) of distribution Normal(loc: torch.Size([170]), scale: torch.Size([170])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529])

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Input In [274], in <cell line: 5>()
3 num_samples = 1000
4 predictive = Predictive(model, guide=guide, num_samples=num_samples)
5 svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
----> 6                for k, v in predictive(log_gdp, is_cont_africa, ruggedness).items()
7                if k != "obs"}

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
1106 # If we don't have any hooks, we want to skip the rest of the logic in
1107 # this function, and just call forward.
1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
1111 # Do not call functions when jit is used
1112 full_backward_hooks, non_full_backward_hooks = [], []

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
263     return_sites = None if not return_sites else return_sites
264     posterior_samples = _predictive(
265         self.guide,
266         posterior_samples,
(...)
271         model_kwargs=kwargs,
272     )
--> 273 return _predictive(
274     self.model,
275     posterior_samples,
276     self.num_samples,
277     return_sites=return_sites,
278     parallel=self.parallel,
279     model_args=args,
280     model_kwargs=kwargs,
281 )

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:78, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
67 def _predictive(
68     model,
69     posterior_samples,
(...)
75     model_kwargs={},
76 ):
---> 78     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
79     vectorize = pyro.plate(
80         "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
81     )
82     model_trace = prune_subsample_sites(
83         poutine.trace(model).get_trace(*model_args, **model_kwargs)
84     )

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:21, in _guess_max_plate_nesting(model, args, kwargs)
15 """
16 Guesses max_plate_nesting by running the model once
17 without enumeration. This optimistically assumes static model
18 structure.
19 """
20 with poutine.block():
---> 21     model_trace = poutine.trace(model).get_trace(*args, **kwargs)
22 sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
24 dims = [
25     frame.dim
26     for site in sites
27     for frame in site["cond_indep_stack"]
28     if frame.vectorized
29 ]

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
190 def get_trace(self, *args, **kwargs):
191     """
192     :returns: data structure
193     :rtype: pyro.poutine.Trace
(...)
196     Calls this poutine and returns its trace instead of the function's return value.
197     """
--> 198     self(*args, **kwargs)
199     return self.msngr.get_trace()

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
178         exc = exc_type("{}\n{}".format(exc_value, shapes))
179         exc = exc.with_traceback(traceback)
--> 180         raise exc from e
182         "_RETURN", name="_RETURN", type="return", value=ret
183     )
184 return ret

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
172 )
173 try:
--> 174     ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176     exc_type, exc_value, traceback = sys.exc_info()

24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26     with self.clone():
---> 27         return func(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11     with context:
---> 12         return fn(*args, **kwargs)

Input In [251], in model(is_cont_africa, ruggedness, log_gdp)
8 mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
10 with pyro.plate("data", len(ruggedness)):
---> 11     return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/distributions/distribution.py:18, in DistributionMeta.__call__(cls, *args, **kwargs)
16     if result is not None:
17         return result
---> 18 return super().__call__(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
48 else:
49     batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
53         valid = constraint.check(value)
54         if not valid.all():
---> 55             raise ValueError(
56                 f"Expected parameter {param} "
57                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58                 f"of distribution {repr(self)} "
59                 f"to satisfy the constraint {repr(constraint)}, "
60                 f"but found invalid values:\n{value}"
61             )
62 super(Distribution, self).__init__()

ValueError: Expected parameter scale (Tensor of shape (170,)) of distribution Normal(loc: torch.Size([170]), scale: torch.Size([170])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
-2.7529, -2.7529])
Trace Shapes:
Param Sites:
Sample Sites:
a dist     |
value     |
bA dist     |
value     |
bR dist     |
value     |
bAR dist     |
value     |
sigma dist     |
value     |
data dist     |
value 170 |
``````
``````
sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
``````

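The `sigma` sample site (and perhaps other sample sites) needs a distribution with the correct support. The `constraint=` argument of `pyro.param` only constrains that learnable parameter; it does not filter or restrict the values drawn at a sample site. A `Normal` has support on the whole real line, so a `Normal`-distributed `sigma` can come out negative. In your traceback, the model's prior `dist.Normal(0., 10.)` does this half the time when `Predictive` first runs the model on its own (`_guess_max_plate_nesting`), which is why the error appears only sometimes; the guide's `Normal(sigma_loc, 0.05)` can also produce negative values. For example, you might instead use something like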

``````
sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
``````

Thank you! Now it works!!