invalid value occurring during SVI

Hi,
I’m new to pyro and trying to understand the basics of Bayesian regression from a Bayesian Linear Regression example. (Bayesian Regression - Inference Algorithms (Part 2) — Pyro Tutorials 1.8.1 documentation)
But when I run the svi inference part, sometimes I get ValueError saying that the scale parameter of the normal distribution is invalid values. (inference code and full error message attached below)
It seems that the error occurs when a negative value is sampled for the scale parameter, which is strange because positive constraints are applied to all scale parameters.

So, my questions are

  • Is there a way to avoid such a ValueError besides running the code over and over again until it works?
  • How constraints on the sample work? I thought they work as “hard” constraints, so samples not satisfying the constraints are filtered before they are evaluated for the inference. But it seems like negative scale value (which violates constraints) is somehow sampled and used. So I’m misunderstanding something, but I’m not sure what it is.
def model(is_cont_africa, ruggedness, log_gdp=None):
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a = pyro.sample("bA", dist.Normal(0., 1.))
    b_r = pyro.sample("bR", dist.Normal(0., 1.))
    b_ar = pyro.sample("bAR", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Normal(0., 10.))
    
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
    
    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)
    
def guide(is_cont_africa, ruggedness, log_gdp):
    a_loc = pyro.param('a_loc', lambda: torch.tensor(0.))
    a_scale = pyro.param('a_scale', lambda:torch.tensor(1.), constraint=constraints.positive)
    b_loc = pyro.param('b_loc', lambda: torch.randn(3))
    b_scale = pyro.param('b_scale', lambda: torch.ones(3), constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(1.), constraint=constraints.positive)
    
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))
    b_a = pyro.sample("bA", dist.Normal(b_loc[0], b_scale[0]))
    b_r = pyro.sample("bR", dist.Normal(b_loc[1], b_scale[1]))
    b_ar = pyro.sample("bAR", dist.Normal(b_loc[2], b_scale[2]))
    sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))
    mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness

num_samples = 1000
predictive = Predictive(model, guide=guide, num_samples=num_samples)
svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
               for k, v in predictive(log_gdp, is_cont_africa, ruggedness).items()
               if k != "obs"}

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     26 with self.clone():
---> 27     return func(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

Input In [251], in model(is_cont_africa, ruggedness, log_gdp)
     10 with pyro.plate("data", len(ruggedness)):
---> 11     return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/distributions/distribution.py:18, in DistributionMeta.__call__(cls, *args, **kwargs)
     17         return result
---> 18 return super().__call__(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
     49     batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     54         if not valid.all():
---> 55             raise ValueError(
     56                 f"Expected parameter {param} "
     57                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                 f"of distribution {repr(self)} "
     59                 f"to satisfy the constraint {repr(constraint)}, "
     60                 f"but found invalid values:\n{value}"
     61             )
     62 super(Distribution, self).__init__()

ValueError: Expected parameter scale (Tensor of shape (170,)) of distribution Normal(loc: torch.Size([170]), scale: torch.Size([170])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529])

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
Input In [274], in <cell line: 5>()
      3 num_samples = 1000
      4 predictive = Predictive(model, guide=guide, num_samples=num_samples)
      5 svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
----> 6                for k, v in predictive(log_gdp, is_cont_africa, ruggedness).items()
      7                if k != "obs"}

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:78, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     67 def _predictive(
     68     model,
     69     posterior_samples,
   (...)
     75     model_kwargs={},
     76 ):
     77     model = torch.no_grad()(poutine.mask(model, mask=False))
---> 78     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
     79     vectorize = pyro.plate(
     80         "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
     81     )
     82     model_trace = prune_subsample_sites(
     83         poutine.trace(model).get_trace(*model_args, **model_kwargs)
     84     )

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/infer/predictive.py:21, in _guess_max_plate_nesting(model, args, kwargs)
     15 """
     16 Guesses max_plate_nesting by running the model once
     17 without enumeration. This optimistically assumes static model
     18 structure.
     19 """
     20 with poutine.block():
---> 21     model_trace = poutine.trace(model).get_trace(*args, **kwargs)
     22 sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
     24 dims = [
     25     frame.dim
     26     for site in sites
     27     for frame in site["cond_indep_stack"]
     28     if frame.vectorized
     29 ]

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
    178         exc = exc_type("{}\n{}".format(exc_value, shapes))
    179         exc = exc.with_traceback(traceback)
--> 180         raise exc from e
    181     self.msngr.trace.add_node(
    182         "_RETURN", name="_RETURN", type="return", value=ret
    183     )
    184 return ret

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

Input In [251], in model(is_cont_africa, ruggedness, log_gdp)
      8 mean = a + b_a * is_cont_africa + b_r * ruggedness + b_ar * is_cont_africa * ruggedness
     10 with pyro.plate("data", len(ruggedness)):
---> 11     return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pyro/distributions/distribution.py:18, in DistributionMeta.__call__(cls, *args, **kwargs)
     16     if result is not None:
     17         return result
---> 18 return super().__call__(*args, **kwargs)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/normal.py:50, in Normal.__init__(self, loc, scale, validate_args)
     48 else:
     49     batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)

File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/distributions/distribution.py:55, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     53         valid = constraint.check(value)
     54         if not valid.all():
---> 55             raise ValueError(
     56                 f"Expected parameter {param} "
     57                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                 f"of distribution {repr(self)} "
     59                 f"to satisfy the constraint {repr(constraint)}, "
     60                 f"but found invalid values:\n{value}"
     61             )
     62 super(Distribution, self).__init__()

ValueError: Expected parameter scale (Tensor of shape (170,)) of distribution Normal(loc: torch.Size([170]), scale: torch.Size([170])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([-2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529, -2.7529,
        -2.7529, -2.7529])
Trace Shapes:      
 Param Sites:      
Sample Sites:      
       a dist     |
        value     |
      bA dist     |
        value     |
      bR dist     |
        value     |
     bAR dist     |
        value     |
   sigma dist     |
        value     |
    data dist     |
        value 170 |
sigma = pyro.sample("sigma", dist.Normal(sigma_loc, torch.tensor(0.05)))

the sigma sample site (and perhaps other sample sites) need to have correct supports. e.g. you might instead use something like

sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))

Thank you! Now it works!!