Doing inference with two priors

I am trying to solve a simple inference problem, but with two priors (on mu and sigma):

import torch
import pyro

def measurements():
    # Generative model: priors on mu and sigma, then an observation x ~ Normal(mu, sigma).
    mu = pyro.sample('mu', pyro.distributions.Normal(0,50))  # prior on the mean
    sigma = pyro.sample('sigma', pyro.distributions.Uniform(0,25))  # prior on the scale; any sigma the guide proposes must lie in this Uniform's support
    x = pyro.sample('x', pyro.distributions.Normal(mu,sigma))  # Normal rejects a non-positive scale ("The parameter scale has invalid values")

    # Returns x: freshly sampled, or the observed value when run under pyro.condition.
    return x

conditioned_measurements = pyro.condition(measurements, data={"x": -27.020})  # NOTE(review): observed value is a Python float; log_prob at site 'x' needs a Tensor, e.g. torch.tensor(-27.020)

from torch.distributions import constraints

def measurements_parametrized_guide_constrained():  # guide: q(mu) = Normal(a, b), q(sigma) = Normal(c, d)
    a = pyro.param("a", torch.tensor(0.))  # mean of q(mu), unconstrained
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)  # scale of q(mu)
    c = pyro.param("a", torch.tensor(0.), constraint=constraints.positive)  # BUG: uses the param name "a" instead of "c"; also 0. is not a valid initial value under constraints.positive
    d = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)  # BUG: uses the param name "b" instead of "d"
    return pyro.sample("mu", pyro.distributions.Normal(a, b)),pyro.sample("sigma", pyro.distributions.Normal(c, d))  # NOTE(review): a Normal q(sigma) can propose sigma <= 0, outside the model's Uniform(0,25) prior and invalid as a Normal scale

pyro.clear_param_store()  # start SVI from a fresh parameter store
svi = pyro.infer.SVI(model=conditioned_measurements,  # fit the guide to the model conditioned on x
                     guide=measurements_parametrized_guide_constrained,
                     optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
                     loss=pyro.infer.Trace_ELBO())
# Training loop: one SVI step per iteration, recording the loss and
# the current value of each variational parameter.
losses, a,b,c,d  = [], [], [], [], []  # module-level history lists, separate from the guide's local a, b, c, d
num_steps = 2500

for t in range(num_steps):
    losses.append(svi.step())  # svi.step() returns the loss estimate for this step
    a.append(pyro.param("a").item())
    b.append(pyro.param("b").item())
    c.append(pyro.param("c").item())  # NOTE(review): requires the guide to register a param named "c" (this guide reuses "a")
    d.append(pyro.param("d").item())  # likewise requires a param named "d" (this guide reuses "b")

And I got the following error:

ValueError: Error while computing log_prob at site 'x':
The value argument to log_prob must be a Tensor
Trace Shapes:  
 Param Sites:  
Sample Sites:  
      mu dist |
        value |
     log_prob |
   sigma dist |
        value |
     log_prob |
       x dist |
        value |

The part I’m not sure about is whether I am returning the two posteriors from the guide correctly, but the error seems to be about something else…

Update: I have corrected some of the bugs identified in the replies and by myself, as follows:

import torch
import pyro

def measurements():
    # Generative model: priors on mu and sigma, then an observation x ~ Normal(mu, sigma).
    mu = pyro.sample('mu', pyro.distributions.Normal(0,50))  # prior on the mean
    sigma = pyro.sample('sigma', pyro.distributions.Uniform(0.001,25))  # prior on the scale; 0.001 lower bound keeps sigma > 0. NOTE(review): the guide must only propose sigma inside this support, or log_prob at site 'sigma' fails
    x = pyro.sample('x', pyro.distributions.Normal(mu,sigma))  # Normal requires a positive scale

    # Returns x: freshly sampled, or the observed value when run under pyro.condition.
    return x

conditioned_measurements = pyro.condition(measurements, data={"x": torch.tensor(-27.020)})  # observed value passed as a Tensor, as log_prob at site 'x' requires

from torch.distributions import constraints

def measurements_parametrized_guide_constrained():  # guide: q(mu) = Normal(a, b), q(sigma) = Uniform(c, d)
    a = pyro.param("a", torch.tensor(0.))  # mean of q(mu), unconstrained
    b = pyro.param("b", torch.tensor(50.), constraint=constraints.positive)  # scale of q(mu)
    c = pyro.param("c", torch.tensor(0.001), constraint=constraints.positive)  # lower bound of q(sigma); only constrained to be positive, so c < d is not enforced
    d = pyro.param("d", torch.tensor(25.), constraint=constraints.positive)  # upper bound of q(sigma); NOTE(review): nothing stops SGD from moving it above 25, the prior's upper bound
    return pyro.sample("mu", pyro.distributions.Normal(a, b)),pyro.sample("sigma", pyro.distributions.Uniform(c, d))  # if Uniform(c, d) extends past the prior Uniform(0.001, 25), a sampled sigma can fall outside the prior's support and log_prob at site 'sigma' raises

pyro.clear_param_store()  # start SVI from a fresh parameter store
svi = pyro.infer.SVI(model=conditioned_measurements,  # fit the guide to the model conditioned on x
                     guide=measurements_parametrized_guide_constrained,
                     optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
                     loss=pyro.infer.Trace_ELBO())
# Training loop: one SVI step per iteration, recording the loss and
# the current value of each variational parameter.
losses, a,b,c,d  = [], [], [], [], []  # module-level history lists, separate from the guide's local a, b, c, d
num_steps = 2500

for t in range(num_steps):
    losses.append(svi.step())  # svi.step() returns the loss estimate for this step
    a.append(pyro.param("a").item())  # mean of q(mu)
    b.append(pyro.param("b").item())  # scale of q(mu)
    c.append(pyro.param("c").item())  # lower bound of q(sigma)
    d.append(pyro.param("d").item())  # upper bound of q(sigma); watch for it exceeding 25

Unfortunately, there is still an error, but a different one from before:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
    229                     try:
--> 230                         log_p = site["fn"].log_prob(
    231                             site["value"], *site["args"], **site["kwargs"]

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/uniform.py in log_prob(self, value)
     72         if self._validate_args:
---> 73             self._validate_sample(value)
     74         lb = self.low.le(value).type_as(self.low)

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
    278 

ValueError: The value argument must be within the support

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-122-78d4c23ea3f5> in <module>
     33 
     34 for t in range(num_steps):
---> 35     losses.append(svi.step())
     36     a.append(pyro.param("a").item())
     37     b.append(pyro.param("b").item())

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    143         # get loss and compute gradients
    144         with poutine.trace(param_only=True) as param_capture:
--> 145             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    146 
    147         params = set(

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    138         loss = 0.0
    139         # grab a trace from the generator
--> 140         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142                 model_trace, guide_trace

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
    180         else:
    181             for i in range(self.num_particles):
--> 182                 yield self._get_trace(model, guide, args, kwargs)

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
     55         against it.
     56         """
---> 57         model_trace, guide_trace = get_importance_trace(
     58             "flat", self.max_plate_nesting, model, guide, args, kwargs
     59         )

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     73     model_trace = prune_subsample_sites(model_trace)
     74 
---> 75     model_trace.compute_log_prob()
     76     guide_trace.compute_score_parts()
     77     if is_validation_enabled():

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
    234                         _, exc_value, traceback = sys.exc_info()
    235                         shapes = self.format_shapes(last_site=site["name"])
--> 236                         raise ValueError(
    237                             "Error while computing log_prob at site '{}':\n{}\n{}".format(
    238                                 name, exc_value, shapes

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
    228                 if "log_prob" not in site:
    229                     try:
--> 230                         log_p = site["fn"].log_prob(
    231                             site["value"], *site["args"], **site["kwargs"]
    232                         )

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/uniform.py in log_prob(self, value)
     71     def log_prob(self, value):
     72         if self._validate_args:
---> 73             self._validate_sample(value)
     74         lb = self.low.le(value).type_as(self.low)
     75         ub = self.high.gt(value).type_as(self.low)

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/distribution.py in _validate_sample(self, value)
    275         assert support is not None
    276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
    278 
    279     def _get_checked_instance(self, cls, _instance=None):

ValueError: Error while computing log_prob at site 'sigma':
The value argument must be within the support
Trace Shapes:  
 Param Sites:  
Sample Sites:  
      mu dist |
        value |
     log_prob |
   sigma dist |
        value |


Hi @zyzhang1130,

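Try conditioning on a tensor (the observed value passed to `pyro.condition` must be a `torch.Tensor`, not a Python float, which is what "The value argument to log_prob must be a Tensor" is complaining about):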


conditioned_measurements = pyro.condition(
    measurements, data={"x": torch.tensor(-27.020)}  # observed value as a Tensor, as log_prob at site 'x' requires
)

Hi, thank you for your suggestion.
Now it shows the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

<ipython-input-43-11b0340d74a9> in measurements()
      4     sigma = pyro.sample('sigma', pyro.distributions.Uniform(0,25))
----> 5     x = pyro.sample('x', pyro.distributions.Normal(mu,sigma))
      6 

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     17                 return result
---> 18         return super().__call__(*args, **kwargs)
     19 

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()

ValueError: The parameter scale has invalid values

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<ipython-input-53-fab342f7bb4b> in <module>
     10 
     11 for t in range(num_steps):
---> 12     losses.append(svi.step())
     13     a.append(pyro.param("a").item())
     14     b.append(pyro.param("b").item())

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    143         # get loss and compute gradients
    144         with poutine.trace(param_only=True) as param_capture:
--> 145             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    146 
    147         params = set(

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    138         loss = 0.0
    139         # grab a trace from the generator
--> 140         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142                 model_trace, guide_trace

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
    180         else:
    181             for i in range(self.num_particles):
--> 182                 yield self._get_trace(model, guide, args, kwargs)

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
     55         against it.
     56         """
---> 57         model_trace, guide_trace = get_importance_trace(
     58             "flat", self.max_plate_nesting, model, guide, args, kwargs
     59         )

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     63         if detach:
     64             guide_trace.detach_()
---> 65         model_trace = poutine.trace(
     66             poutine.replay(model, trace=guide_trace), graph_type=graph_type
     67         ).get_trace(*args, **kwargs)

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    178                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    179                 exc = exc.with_traceback(traceback)
--> 180                 raise exc from e
    181             self.msngr.trace.add_node(
    182                 "_RETURN", name="_RETURN", type="return", value=ret

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

<ipython-input-43-11b0340d74a9> in measurements()
      3     mu = pyro.sample('mu', pyro.distributions.Normal(0,50))
      4     sigma = pyro.sample('sigma', pyro.distributions.Uniform(0,25))
----> 5     x = pyro.sample('x', pyro.distributions.Normal(mu,sigma))
      6 
      7 

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/pyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     16             if result is not None:
     17                 return result
---> 18         return super().__call__(*args, **kwargs)
     19 
     20     @property

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

/opt/anaconda3/envs/unlearning/lib/python3.8/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     51                     continue  # skip checking lazily-constructed args
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()
     55 

ValueError: The parameter scale has invalid values
Trace Shapes:  
 Param Sites:  
Sample Sites:  
      mu dist |
        value |
   sigma dist |
        value |

Did you mean to write the following instead?

 c = pyro.param("c", torch.tensor(0.), constraint=constraints.positive)  # NOTE(review): 0. is not a valid initial value under constraints.positive
 d = pyro.param("d", torch.tensor(1.), constraint=constraints.positive)

Hi, yes. Sorry for the basic mistake. But after I corrected it, the error still persists.

Pyro is complaining that 0 is not a positive value. Either relax `c` to `constraints.real` or initialize it to a positive value.

Yes, I noticed that. So in my latest code I changed it to

c = pyro.param("c", torch.tensor(0.001), constraint=constraints.positive)  # positive initial value, valid under constraints.positive

(see the update in my original post), but there is still an error. By the way, I have also tried

c = pyro.param("c", torch.tensor(0.001), constraint=constraints.greater_than_eq(0))  # NOTE(review): changing c's constraint does not stop d from exceeding 25, the prior's upper bound

and it didn’t work either.

I figured out the issue, but to solve it I think I need to put multiple constraints on my posterior. Any idea how to do that?

What was the issue, and what was your solution?

Can you explain in more detail what you mean by multiple constraints on your posterior? Pyro supports constraints on parameters and priors on latent variables, but I’m not sure what you mean by a constraint on a latent variable.

The issue is that in measurements() the support of the Uniform prior is confined to [0.001, 25], so during SVI, if d becomes larger than 25, the guide can sample a sigma outside that support and Pyro throws this error. I’m not sure why there is such a rule, because I was simply trying to give a prior on [0.001, 25]. It was resolved by setting

  sigma = pyro.sample('sigma', pyro.distributions.Uniform(0.001,10000))  # NOTE(review): widens the prior (changes the model) instead of keeping the guide's Uniform(c, d) inside [0.001, 25]

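i.e., by widening the support of the prior. (Note that this changes the model itself; the underlying requirement is that every sigma the guide samples lies within the prior's support.)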