How to get the parameters generated by an EasyGuide using the .parameters() method?

Hello all

I have written an EasyGuide for a small regression problem, but I could not get the guide's parameters after initializing it.

Then I found a way of initializing the parameters by following "Example: Sparse Bayesian Linear Regression" in the Pyro Tutorials 1.8.4 documentation. That worked, but the parameters still did not show up when using the .parameters() method.

Minimal reproducible snippet:

In [1]: import pyro
   ...: import torch
   ...: import pyro.distributions as dist
   ...: from pyro.contrib.easyguide import easy_guide, EasyGuide
   ...: from pyro.nn import PyroModule, PyroSample, PyroParam
   ...: from torch.distributions import constraints
   ...: import numpy as np
   ...: 
   ...: torch.manual_seed(42)
   ...: pyro.set_rng_seed(42)
   ...: pyro.__version__
Out[1]: '1.6.0'

In [2]: class BayesianRegression(PyroModule):
   ...:     def __init__(self, in_features, out_features):
   ...:         super().__init__()
   ...:         self.linear = PyroModule[torch.nn.Linear](in_features, out_features)
   ...:         self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
   ...:         self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))
   ...: 
   ...:     def forward(self, x, full_size, y=None):
   ...:         sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
   ...:         mean = self.linear(x).squeeze(-1)
   ...:         with pyro.plate("data", size=full_size, subsample_size=x.shape[0]):
   ...:             obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
   ...:         return mean
   ...: 

In [3]: base_regression_model = BayesianRegression(1, 1)

In [4]: @easy_guide(base_regression_model)
   ...: def regression_guide(self, x, full_size, y=None):
   ...:     group = self.group(match=".*")
   ...:     loc = pyro.param("loc", torch.randn(group.event_shape))
   ...:     scale = pyro.param("scale", torch.ones(group.event_shape)*0.01, constraint=constraints.positive)
   ...:     group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))
   ...: 

In [5]: pyro.clear_param_store()
   ...: regression_guide(x=torch.ones(10, 1), full_size=100)
   ...: list(regression_guide.parameters())  # Unable to get parameters
Out[5]: []

In [6]: pyro.clear_param_store()
   ...: with pyro.poutine.block(), pyro.poutine.trace(param_only=True) as param_capture: 
   ...:     regression_guide(x=torch.ones(10, 1), full_size=100)
   ...: params = list([pyro.param(name).unconstrained() for name in param_capture.trace])
   ...: params
Out[6]: 
[tensor([ 1.8928,  1.3067, -0.0662, -0.4235, -2.3768,  0.0641, -0.3435,  1.2287,
         -0.2754, -0.2109,  0.9287, -0.2282, -1.2179], requires_grad=True),
 tensor([-4.6052, -4.6052, -4.6052, -4.6052, -4.6052, -4.6052, -4.6052, -4.6052,
         -4.6052, -4.6052, -4.6052, -4.6052, -4.6052], requires_grad=True)]

In [7]: list(regression_guide.parameters())  # still cannot get the parameters
Out[7]: []

In [8]: pyro.clear_param_store()
   ...: regression_guide = pyro.infer.autoguide.AutoNormal(base_regression_model)
   ...: regression_guide(x=torch.ones(10, 1), full_size=100)
   ...: list(regression_guide.parameters())  # Getting params from AutoNormal works well
Out[8]: 
[Parameter containing:
 tensor(0., requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.], requires_grad=True),
 Parameter containing:
 tensor([[0.]], requires_grad=True),
 Parameter containing:
 tensor([0.], requires_grad=True),
 Parameter containing:
 tensor(-2.3026, requires_grad=True),
 Parameter containing:
 tensor([-2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026, -2.3026,
         -2.3026, -2.3026, -2.3026, -2.3026], requires_grad=True),
 Parameter containing:
 tensor([[-2.3026]], requires_grad=True),
 Parameter containing:
 tensor([-2.3026], requires_grad=True)]

So, how can I get the parameters generated by an EasyGuide using the .parameters() method?
Thanks

EasyGuides don’t store the values of pyro.param statements. You can access the parameters created in your guide in the usual way, by looking up their values in the global parameter store with pyro.param(name).

If you want your guide object to have parameters attached to it, you can subclass EasyGuide and define the parameters in its __init__ method, just as you did with your model:

class MyGuide(EasyGuide):
    def __init__(self, model):
        super().__init__(model)  # EasyGuide's constructor takes the model
        self.loc = PyroParam(...)
        ...
    def guide(self, x, full_size, y=None):
        group = self.group(match=".*")
        group.sample("joint", dist.Normal(loc=self.loc, scale=self.scale).to_event(1))

Hi @eb8680_2, thanks for your reply.
I tried using PyroParam, but it resulted in an error.

Here is a minimal reproducible snippet:

In [1]: import pyro
   ...: import torch
   ...: import pyro.distributions as dist
   ...: from pyro.contrib.easyguide import easy_guide, EasyGuide
   ...: from pyro.nn import PyroModule, PyroSample, PyroParam
   ...: from torch.distributions import constraints
   ...: import numpy as np
   ...: 
   ...: torch.manual_seed(42)
   ...: pyro.set_rng_seed(42)
   ...: pyro.__version__
Out[1]: '1.6.0'

In [2]: class BayesianRegression(PyroModule):
   ...:     def __init__(self, in_features, out_features):
   ...:         super().__init__()
   ...:         self.linear = PyroModule[torch.nn.Linear](in_features, out_features)
   ...:         self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
   ...:         self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))
   ...: 
   ...:     def forward(self, x, full_size, y=None):
   ...:         sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
   ...:         mean = self.linear(x).squeeze(-1)
   ...:         with pyro.plate("data", size=full_size, subsample_size=x.shape[0]):
   ...:             obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
   ...:         return mean

In [3]: base_regression_model = BayesianRegression(1, 1)

In [4]: class RegressionGuide(EasyGuide):
   ...:     def __init__(self, model):
   ...:         super().__init__(model)
   ...: 
   ...:     def guide(self, x, full_size, y=None):
   ...:         group = self.group(match=".*")
   ...:         loc = PyroParam(torch.randn(group.event_shape))
   ...:         scale = PyroParam(torch.ones(group.event_shape)*0.01, constraint=constraints.positive)
   ...:         group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))

In [5]: pyro.clear_param_store()
   ...: guide = RegressionGuide(base_regression_model)
   ...: guide(torch.randn((10, 1)), full_size=100)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-d8c78b92d548> in <module>
      1 pyro.clear_param_store()
      2 guide = RegressionGuide(base_regression_model)
----> 3 guide(torch.randn((10, 1)), full_size=100)

~/miniconda3/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    411     def __call__(self, *args, **kwargs):
    412         with self._pyro_context:
--> 413             return super().__call__(*args, **kwargs)
    414 
    415     def __getattr__(self, name):

~/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~/miniconda3/lib/python3.8/site-packages/pyro/contrib/easyguide/easyguide.py in forward(self, *args, **kwargs)
     97         if self.prototype_trace is None:
     98             self._setup_prototype(*args, **kwargs)
---> 99         result = self.guide(*args, **kwargs)
    100         self.plates.clear()
    101         return result

<ipython-input-4-45cb7318750e> in guide(self, x, full_size, y)
      7         loc = PyroParam(torch.randn(group.event_shape))
      8         scale = PyroParam(torch.ones(group.event_shape)*0.01, constraint=constraints.positive)
----> 9         group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))
     10 

~/miniconda3/lib/python3.8/site-packages/pyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
     16             if result is not None:
     17                 return result
---> 18         return super().__call__(*args, **kwargs)
     19 
     20     @property

~/miniconda3/lib/python3.8/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     43 
     44     def __init__(self, loc, scale, validate_args=None):
---> 45         self.loc, self.scale = broadcast_all(loc, scale)
     46         if isinstance(loc, Number) and isinstance(scale, Number):
     47             batch_shape = torch.Size()

~/miniconda3/lib/python3.8/site-packages/torch/distributions/utils.py in broadcast_all(*values)
     27     if not all(isinstance(v, torch.Tensor) or has_torch_function((v,)) or isinstance(v, Number)
     28                for v in values):
---> 29         raise ValueError('Input arguments must all be instances of numbers.Number, '
     30                          'torch.Tensor or objects implementing __torch_function__.')
     31     if not all([isinstance(v, torch.Tensor) or has_torch_function((v,)) for v in values]):

ValueError: Input arguments must all be instances of numbers.Number, torch.Tensor or objects implementing __torch_function__.

Looking at the error message, I think the loc parameter of the Normal distribution needs to be a tensor or a number. So, should I sample from loc and scale after creating them with PyroParam?

The weird thing is that if I change my RegressionGuide to use pyro.param instead, the code works:

class RegressionGuide(EasyGuide):
    def __init__(self, model):
        super().__init__(model)

    def guide(self, x, full_size, y=None):
        group = self.group(match=".*")
        loc = pyro.param("loc", torch.randn(group.event_shape))
        scale = pyro.param('scale', torch.ones(group.event_shape)*0.01, constraint=constraints.positive)
        group.sample("joint", dist.Normal(loc=loc, scale=scale).to_event(1))

I think I am not using PyroParam properly.

Can this be a feature request?
Thanks

As discussed in the documentation, PyroParams represent deferred calls to pyro.param: setting them as named attributes of a pyro.nn.PyroModule (of which EasyGuides and AutoGuides are subclasses) associates them with a name, and accessing those attributes actually calls pyro.param. PyroSample works the same way for pyro.sample.

This enables a more PyTorch-friendly programming style where parameters may be defined or initialized in the constructor of a model instead of its main body.

Your example code incorrectly treats PyroParam as an eager pyro.param call that immediately returns an nn.Parameter. Here is a modified version of your guide that illustrates two ways to use PyroParam correctly:

class RegressionGuide(EasyGuide):
    def __init__(self, model):
        super().__init__(model)
        self.loc = PyroParam(lambda: torch.randn(self.group(match=".*").event_shape))

    @PyroParam(constraint=constraints.positive)
    def scale(self):
        return 0.01*torch.ones(self.group(match=".*").event_shape)

    def guide(self, x, full_size, y=None):
        group = self.group(match=".*")
        group.sample("joint", dist.Normal(loc=self.loc, scale=self.scale).to_event(1))

Note that the initializers for self.loc and self.scale (the lambda and the decorated method) will only be called the first time those attributes are accessed, so it’s OK for them to refer to self.group(match=".*") in this case, since that group is first created before self.loc and self.scale are accessed in guide.

Can this be a feature request?

You’re welcome to open a feature request issue on GitHub with a more detailed proposal. Personally, I would recommend writing your guide as a subclass of EasyGuide and defining some parameters in __init__; that way you don’t have parameter initialization logic cluttering up the main body of your guide program.


Thank you for the detailed answer, @eb8680_2. It helped a lot :slight_smile: