Gamma distribution modeling concentration and rate - Predictive error

I’m working on modeling a gamma distribution where both parameters are functions of the data. I’m currently using just a simple two-feature model, after generating random data following this example.

I am mostly following the Bayesian Regression - Introduction (Part 1) tutorial, except that I changed the distribution and model both gamma parameters as functions of the data. First question: is the following model definition correct for that?

class BayesianRegression_LogGamma_shape(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 2.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 2.).expand([out_features]).to_event(1))
        self.linear_shape = PyroModule[nn.Linear](in_features, out_features)
        self.linear_shape.weight = PyroSample(dist.Normal(0., 2.).expand([out_features, in_features]).to_event(2))
        self.linear_shape.bias = PyroSample(dist.Normal(0., 2.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        mu = self.linear(x).squeeze(-1).exp()
        shape = self.linear_shape(x).squeeze(-1).exp()

        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Gamma(concentration=shape, rate=shape / mu), obs=y)
        return mu, shape
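For reference on this parameterization: a Gamma(concentration, rate) has mean concentration / rate and variance concentration / rate**2, so with concentration = shape and rate = shape / mu

    E[y]   = shape / (shape / mu)    = mu
    Var[y] = shape / (shape / mu)**2 = mu**2 / shape

i.e. the first linear layer models the mean and the second one controls the dispersion around it.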

The Trace_ELBO loss looks reasonable compared to a simpler version, so I think it’s working correctly, and the model trains without error. However, I get an error when trying to run the Predictive class.

def summary(samples):
    site_stats = {}
    for k, v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            "median": torch.median(v, 0).values,
            "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
        }
    return site_stats


predictive = Predictive(model, guide=guide, num_samples=1000, return_sites=("linear.weight", "obs", "_RETURN", 'linear.bias', 'mu', 'shape'))

with data.converter_train.make_torch_dataloader(batch_size=BATCH_SIZE*10) as train_dataloader:
    train_dataloader_iter = iter(train_dataloader)
    pd_batch = next(train_dataloader_iter)
    pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1).float()    
    inputs = pd_batch['features'].to(device)
    labels = pd_batch[y_name].to(device)
    samples = predictive(inputs)
    pred_summary = summary(samples)

The error is “TypeError: expected Tensor as element 0 in argument 0, but got tuple”

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-34-d6e380deacde> in <module>
     22     print(inputs)
     23     labels = pd_batch[y_name].to(device)
---> 24     samples = predictive(inputs)
     25     pred_summary = summary(samples)

C:\ProgramData\Anaconda3\envs\CurvGH_202010\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\predictive.py in forward(self, *args, **kwargs)
    204                                             parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    205         return _predictive(self.model, posterior_samples, self.num_samples, return_sites=return_sites,
--> 206                            parallel=self.parallel, model_args=args, model_kwargs=kwargs)
    207 
    208     def get_samples(self, *args, **kwargs):

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     91     if not parallel:
     92         return _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples,
---> 93                                       return_site_shapes, return_trace=False)
     94 
     95     trace = poutine.trace(poutine.condition(vectorize(model), reshaped_samples))\

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\predictive.py in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     46     else:
     47         return {site: torch.stack([s[site] for s in collected]).reshape(shape)
---> 48                 for site, shape in return_site_shapes.items()}
     49 
     50 

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\predictive.py in <dictcomp>(.0)
     46     else:
     47         return {site: torch.stack([s[site] for s in collected]).reshape(shape)
---> 48                 for site, shape in return_site_shapes.items()}
     49 
     50 

TypeError: expected Tensor as element 0 in argument 0, but got tuple

The forward method was the issue: it needs to return a single tensor containing mu and shape, not a tuple.

def forward(self, x, y=None):
    mu = self.linear(x).squeeze(-1).exp()
    shape = self.linear_shape(x).squeeze(-1).exp()

    with pyro.plate("data", x.shape[0]):
        obs = pyro.sample("obs", dist.Gamma(concentration=shape, rate=shape / mu), obs=y)
    return torch.cat((mu, shape), 0)

Gosh… I had been thinking about this issue for a while. You are right: because you include _RETURN in the return_sites of Predictive, you need to make sure that the return value of the model is a tensor.
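As an alternative to concatenating mu and shape along the batch dimension, you could record them as deterministic sites so that Predictive can return them by name. A rough sketch:

def forward(self, x, y=None):
    mu = pyro.deterministic("mu", self.linear(x).squeeze(-1).exp())
    shape = pyro.deterministic("shape", self.linear_shape(x).squeeze(-1).exp())

    with pyro.plate("data", x.shape[0]):
        obs = pyro.sample("obs", dist.Gamma(concentration=shape, rate=shape / mu), obs=y)
    # stacking on a new trailing dim keeps the two quantities separable:
    # result[..., 0] is mu and result[..., 1] is shape
    return torch.stack((mu, shape), dim=-1)

With those sites in place, the 'mu' and 'shape' entries in your return_sites should be populated as well.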

I am now working on extending this model to allow for zeros in the outcome. I thought dist.ZeroInflatedDistribution would allow for this.

class BayesianRegression_LogGamma_shape_zeroInf(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 2.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 2.).expand([out_features]).to_event(1))
        self.linear_shape = PyroModule[nn.Linear](in_features, out_features)
        self.linear_shape.weight = PyroSample(dist.Normal(0., 2.).expand([out_features, in_features]).to_event(2))
        self.linear_shape.bias = PyroSample(dist.Normal(0., 2.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        mu = self.linear(x).squeeze(-1).exp()
        shape = self.linear_shape(x).squeeze(-1).exp()
        gate = pyro.sample("gate", dist.Uniform(.0001, 1.0))

        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.ZeroInflatedDistribution(dist.Gamma(concentration=shape, rate=shape / mu), gate=gate), obs=y)
        return torch.cat((mu, shape), 0)

However, I’m getting an error: ValueError: The value argument must be within the support. If I add a small positive amount wherever y = 0, it works again, so it’s definitely the zeros causing the error. How can I allow for zeros?

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_struct.py in compute_log_prob(self, site_filter)
215                     try:
--> 216                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
217                     except ValueError as e:

~\AppData\Roaming\Python\Python37\site-packages\pyro\distributions\zero_inflated.py in log_prob(self, value)
 64         if self._validate_args:
---> 65             self._validate_sample(value)
 66 

C:\ProgramData\Anaconda3\envs\lib\site-packages\torch\distributions\distribution.py in _validate_sample(self, value)
276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
278 

ValueError: The value argument must be within the support

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<timed eval> in <module>

<ipython-input-261-05cb1fa637bb> in train_and_evaluate_SVI(svi, lr)
 46             if (epoch + 1) % 100 == 0: print('-' * 10)
 47 
---> 48             train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
 49 
 50             #val_loss = evaluate(model, criterion, val_dataloader_iter, validation_steps, device)

<ipython-input-261-05cb1fa637bb> in train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
 14 
 15         # Track history in training
---> 16         loss = svi.step(inputs, labels)
 17 
 18         # statistics

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\svi.py in step(self, *args, **kwargs)
126         # get loss and compute gradients
127         with poutine.trace(param_only=True) as param_capture:
--> 128             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
129 
130         params = set(site["value"].unconstrained()

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
129         loss = 0.0
130         # grab a trace from the generator
--> 131         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
132             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
133             loss += loss_particle / self.num_particles

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\elbo.py in _get_traces(self, model, guide, args, kwargs)
168         else:
169             for i in range(self.num_particles):
--> 170                 yield self._get_trace(model, guide, args, kwargs)

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
 56         """
 57         model_trace, guide_trace = get_importance_trace(
---> 58             "flat", self.max_plate_nesting, model, guide, args, kwargs)
 59         if is_validation_enabled():
 60             check_if_enumerated(guide_trace)

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
 53     model_trace = prune_subsample_sites(model_trace)
 54 
---> 55     model_trace.compute_log_prob()
 56     guide_trace.compute_score_parts()
 57     if is_validation_enabled():

~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_struct.py in compute_log_prob(self, site_filter)
219                         shapes = self.format_shapes(last_site=site["name"])
220                         raise ValueError("Error while computing log_prob at site '{}':\n{}\n{}"
--> 221                                          .format(name, exc_value, shapes)).with_traceback(traceback) from e
222                     site["unscaled_log_prob"] = log_p
223                     log_p = scale_and_mask(log_p, site["scale"], site["mask"])

~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_struct.py in compute_log_prob(self, site_filter)
214                 if "log_prob" not in site:
215                     try:
--> 216                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
217                     except ValueError as e:
218                         _, exc_value, traceback = sys.exc_info()

~\AppData\Roaming\Python\Python37\site-packages\pyro\distributions\zero_inflated.py in log_prob(self, value)
 63     def log_prob(self, value):
 64         if self._validate_args:
---> 65             self._validate_sample(value)
 66 
 67         if 'gate' in self.__dict__:

C:\ProgramData\Anaconda3\envs\lib\site-packages\torch\distributions\distribution.py in _validate_sample(self, value)
275         assert support is not None
276         if not support.check(value).all():
--> 277             raise ValueError('The value argument must be within the support')
278 
279     def _get_checked_instance(self, cls, _instance=None):

ValueError: Error while computing log_prob at site 'obs':
The value argument must be within the support
       Trace Shapes:           
        Param Sites:           
       Sample Sites:           
  linear.weight dist      | 1 2
               value      | 1 2
            log_prob      |    
    linear.bias dist      | 1  
               value      | 1  
            log_prob      |    
linear_shape.weight dist      | 1 2
               value      | 1 2
            log_prob      |    
  linear_shape.bias dist      | 1  
               value      | 1  
            log_prob      |    
           gate dist      |    
               value      |    
            log_prob      |    
            obs dist 1028 |    
               value 1028 |
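For what it’s worth, the support check can be reproduced outside the model; a minimal sketch (with validation turned on explicitly):

base = dist.Gamma(torch.tensor(2.0), torch.tensor(1.0))
zi = dist.ZeroInflatedDistribution(base, gate=torch.tensor(0.3), validate_args=True)
zi.log_prob(torch.tensor(0.0))  # ValueError: the zero falls outside the base Gamma's positive support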

My understanding now, after reading the documentation and source code some more, is that the zero-inflated distribution class only works for base distributions whose support already includes zero. I’m now working on writing a custom distribution class.

# imports assumed by this snippet
from numbers import Number
from torch.distributions.utils import broadcast_all

class GammaHurdle(dist.torch_distribution.TorchDistribution):
    
    # theta: probability of a zero
    arg_constraints = {'concentration': torch.distributions.constraints.positive, 
                       'rate': torch.distributions.constraints.positive, 
                       'theta': torch.distributions.constraints.interval(0., 1.)}
    support = torch.distributions.constraints._Interval(0, float("inf"))
    has_rsample = True
    
    def __init__(self, concentration, rate, theta, validate_args=None):
        self.concentration, self.rate, self.theta = broadcast_all(concentration, rate, theta)
        if isinstance(concentration, Number) and isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.concentration.size()
        super(GammaHurdle, self).__init__(batch_shape, validate_args=validate_args)
            
    def log_prob(self, value):
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
        if self._validate_args:
            self._validate_sample(value)
        
        return torch.where(value > 0, 
                           (torch.log(1 - self.theta) + self.concentration * torch.log(self.rate) +
                            (self.concentration - 1) * torch.log(value) -
                            self.rate * value - torch.lgamma(self.concentration)), 
                           torch.log(self.theta))
    
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        mask = torch.rand(shape) < self.theta.expand(shape)        
        value = torch._standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
        value[mask] = 0
        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)  # do not record in autograd graph
        return value
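The log_prob above is meant to implement the usual hurdle density: a point mass of theta at zero and the Gamma density scaled by (1 - theta) on the positive part,

    log p(y) = log(theta)                                                                       if y = 0
    log p(y) = log(1 - theta) + alpha*log(beta) + (alpha - 1)*log(y) - beta*y - lgamma(alpha)   if y > 0

with alpha = concentration and beta = rate.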

The new PyroModule class:

class BayesianRegression_LogGamma_shape_zeroInf(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 2.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 2.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        mu = self.linear(x).squeeze(-1).exp()
        shape = pyro.sample("shape", dist.Gamma(.01, .01))  # was: self.linear_shape(x).squeeze(-1).exp()
        theta = pyro.sample("theta", dist.Uniform(.000001, 1.0))

        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", GammaHurdle(concentration=shape, rate=shape / mu, theta=theta), obs=y)
        return mu  # was: torch.cat((mu, shape), 0)

I’m back to getting an error: ValueError: The parameter loc has invalid values. Again, adding a small positive amount to the zeros in y lets the model train. I am currently stuck figuring out why loc is invalid, though.

Full error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in __call__(self, *args, **kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:

~\AppData\Roaming\Python\Python37\site-packages\pyro\nn\module.py in __call__(self, *args, **kwargs)
    412         with self._pyro_context:
--> 413             return super().__call__(*args, **kwargs)
    414 

C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in forward(self, *args, **kwargs)
    725 
--> 726         latent = self.sample_latent(*args, **kwargs)
    727         plates = self._create_plates(*args, **kwargs)

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in sample_latent(self, *args, **kwargs)
    687         """
--> 688         pos_dist = self.get_posterior(*args, **kwargs)
    689         return pyro.sample("_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True})

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in get_posterior(self, *args, **kwargs)
    903         """
--> 904         return dist.Normal(self.loc, self.scale).to_event(1)
    905 

~\AppData\Roaming\Python\Python37\site-packages\pyro\distributions\distribution.py in __call__(cls, *args, **kwargs)
     17                 return result
---> 18         return super().__call__(*args, **kwargs)
     19 

C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\normal.py in __init__(self, loc, scale, validate_args)
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 

C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()

ValueError: The parameter loc has invalid values

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
<timed eval> in <module>

<ipython-input-26-ca50fa2dbbbe> in train_and_evaluate_SVI(svi, lr)
     46             if (epoch + 1) % 100 == 0: print('-' * 10)
     47 
---> 48             train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
     49 
     50             #val_loss = evaluate(model, criterion, val_dataloader_iter, validation_steps, device)

<ipython-input-26-ca50fa2dbbbe> in train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
     14 
     15         # Track history in training
---> 16         loss = svi.step(inputs, labels)
     17 
     18         # statistics

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\svi.py in step(self, *args, **kwargs)
    126         # get loss and compute gradients
    127         with poutine.trace(param_only=True) as param_capture:
--> 128             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    129 
    130         params = set(site["value"].unconstrained()

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    129         loss = 0.0
    130         # grab a trace from the generator
--> 131         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    132             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
    133             loss += loss_particle / self.num_particles

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\elbo.py in _get_traces(self, model, guide, args, kwargs)
    168         else:
    169             for i in range(self.num_particles):
--> 170                 yield self._get_trace(model, guide, args, kwargs)

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
     56         """
     57         model_trace, guide_trace = get_importance_trace(
---> 58             "flat", self.max_plate_nesting, model, guide, args, kwargs)
     59         if is_validation_enabled():
     60             check_if_enumerated(guide_trace)

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     42     and the model that is run against it.
     43     """
---> 44     guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
     45     if detach:
     46         guide_trace.detach_()

~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in __call__(self, *args, **kwargs)
    169                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    170                 exc = exc.with_traceback(traceback)
--> 171                 raise exc from e
    172             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    173         return ret

~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError) as e:
    167                 exc_type, exc_value, traceback = sys.exc_info()

~\AppData\Roaming\Python\Python37\site-packages\pyro\nn\module.py in __call__(self, *args, **kwargs)
    411     def __call__(self, *args, **kwargs):
    412         with self._pyro_context:
--> 413             return super().__call__(*args, **kwargs)
    414 
    415     def __getattr__(self, name):

C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in forward(self, *args, **kwargs)
    724             self._setup_prototype(*args, **kwargs)
    725 
--> 726         latent = self.sample_latent(*args, **kwargs)
    727         plates = self._create_plates(*args, **kwargs)
    728 

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in sample_latent(self, *args, **kwargs)
    686         base ``model``.
    687         """
--> 688         pos_dist = self.get_posterior(*args, **kwargs)
    689         return pyro.sample("_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True})
    690 

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in get_posterior(self, *args, **kwargs)
    902         Returns a diagonal Normal posterior distribution.
    903         """
--> 904         return dist.Normal(self.loc, self.scale).to_event(1)
    905 
    906     def _loc_scale(self, *args, **kwargs):

~\AppData\Roaming\Python\Python37\site-packages\pyro\distributions\distribution.py in __call__(cls, *args, **kwargs)
     16             if result is not None:
     17                 return result
---> 18         return super().__call__(*args, **kwargs)
     19 
     20     @property

C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     51                     continue  # skip checking lazily-constructed args
     52                 if not constraint.check(getattr(self, param)).all():
---> 53                     raise ValueError("The parameter {} has invalid values".format(param))
     54         super(Distribution, self).__init__()
     55 

ValueError: The parameter loc has invalid values
           Trace Shapes:  
            Param Sites:  
  AutoDiagonalNormal.loc 5
AutoDiagonalNormal.scale 5
           Sample Sites:

I guess what you need is support = constraints.greater_than_eq(0.) rather than torch.distributions.constraints._Interval(0, float("inf"))? Your explanation of the issues you faced seems reasonable to me. Gamma’s support is positive, so it does not work with ZeroInflatedDistribution. You could probably make a wrapper for Gamma to bypass that issue:

class RelaxedGamma(Gamma):
    support = constraints.greater_than_eq(0.)
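A sketch of how the wrapper might then be used, keeping your mean parameterization (hypothetical, not tested):

# the relaxed support lets the ZeroInflatedDistribution validation accept exact zeros in y
base = RelaxedGamma(concentration=shape, rate=shape / mu)
obs = pyro.sample("obs", dist.ZeroInflatedDistribution(base, gate=gate), obs=y)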

Thanks for your responses. I wrote a class that inherits from Gamma, but I’m still getting the same error. Is there a way to get more info on the loc estimates in the middle of training?

class GammaHurdle2(dist.Gamma):
    
    # theta: probability of a zero
    arg_constraints = {'concentration': torch.distributions.constraints.positive, 
                       'rate': torch.distributions.constraints.positive, 
                       'theta': torch.distributions.constraints.interval(0., .99)}
    #support = torch.distributions.constraints._Interval(0, float("inf"))
    support = torch.distributions.constraints.greater_than_eq(0.)
    has_rsample = True
    
    def __init__(self, concentration, rate, theta, validate_args=None):
        self.concentration, self.rate, self.theta = broadcast_all(concentration, rate, theta)
        if isinstance(concentration, Number) and isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.concentration.size()
        super(GammaHurdle2, self).__init__(concentration = concentration, rate = rate, validate_args=validate_args)
            
    def log_prob(self, value):
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)

        if self._validate_args:
            self._validate_sample(value)
        
        ret = torch.where(value > 0, 
                           (torch.log(1 - self.theta) + self.concentration * torch.log(self.rate) +
                            (self.concentration - 1) * torch.log(value) -
                            self.rate * value - torch.lgamma(self.concentration)), 
                           torch.log(self.theta))
        return ret
    
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        mask = torch.rand(shape) < self.theta.expand(shape)        
        value = torch._standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
        value[mask] = 0
        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)  # do not record in autograd graph
        return value
    
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(GammaHurdle2, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        new.theta = self.theta.expand(batch_shape)
        super(GammaHurdle2, new).__init__(concentration = new.concentration, rate = new.rate, validate_args=False)
        new._validate_args = self._validate_args
        return new

I switched to a batch size of 1 to see if there’s any particular input or label value that’s messing things up, but nothing stands out. I am also printing out the self.loc and self.scale values at each step. My model has 2 features, so my understanding is that the 5 entries of loc and scale correspond to the 2 input feature weights, the bias term, shape, and theta.
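One way to dump those values after each svi.step() is through the param store (a sketch, assuming an AutoDiagonalNormal guide as in the trace above, which registers its parameters under these names):

# the guide's variational parameters live in Pyro's param store
loc = pyro.param("AutoDiagonalNormal.loc").detach().cpu().numpy()
scale = pyro.param("AutoDiagonalNormal.scale").detach().cpu().numpy()
print(loc, scale)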

Initial values:

,loc,scale
0,-0.21769094,0.099999994
1,-0.76152873,0.099999994
2,0.7582586,0.099999994
3,1.0986123,0.099999994
4,-0.9077194,0.099999994

What I’m seeing is that eventually the loc for shape becomes NaN (NaN gets written out as blank):

,loc,scale
0,-0.0009057,0.10472898
1,-0.59713113,0.09781907
2,0.98354787,0.0903359
3,,
4,-0.5649243,0.1520296

What should I look for as the cause of that NaN?

I usually look at places like the log_prob implementation to see if there is any trouble. You might want to add some print statements before

        ret = torch.where(value > 0, 
                           (torch.log(1 - self.theta) + self.concentration * torch.log(self.rate) +
                            (self.concentration - 1) * torch.log(value) -
                            self.rate * value - torch.lgamma(self.concentration)), 
                           torch.log(self.theta))

to see which values of value/theta/concentration are causing NaN. Then you can use clamp to avoid such extreme values.
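For example, something along these lines inside log_prob, just before the torch.where above (a sketch, with an arbitrary eps):

# print the raw inputs to spot which one goes bad
print(value, self.theta, self.concentration, self.rate)

# clamp away from the boundaries before taking logs
eps = torch.finfo(self.rate.dtype).eps
theta = self.theta.clamp(min=eps, max=1. - eps)
concentration = self.concentration.clamp(min=eps)
rate = self.rate.clamp(min=eps)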

Here’s the last output for (value, concentration, rate, theta, theta.dtype):
(tensor([0.]), tensor([2.5642], grad_fn=<ExpandBackward>), tensor([5.2943], grad_fn=<ExpandBackward>), tensor([0.3997], grad_fn=<ExpandBackward>), torch.float64).

All of those are reasonable values, and the log_prob function returned -0.9170, again the expected calculation. I saw another thread where setting the dtype to float64 fixed the issue, so I added torch.set_default_tensor_type(torch.DoubleTensor) and also added .double() for mu, shape, theta, obs=y.double(), and ret. Still the same error. That other thread is very similar to my issue (gamma, torch.exp(), and NaN), but that fix doesn’t seem to be working for me.

I also clamped log_prob with a min of -5 to see if that would work, but that does not help either.

Is there a function after log_prob I should look into next?

Here’s another twist. I added .01 to all y so that the new minimum is .01. I then rewrote log_prob to be calculated as:

torch.where(value > 0.01,
            (torch.log1p(-self.theta) + self.concentration * torch.log(self.rate) +
             (self.concentration - 1) * torch.log(value) -
             self.rate * value - torch.lgamma(self.concentration)),
            torch.log(self.theta))

This results in a log probability that is exactly the same as before, since I used value > 0.01 as the condition. However, now the model trains successfully and the parameter estimates are close to what I expected. But why would those changes allow the model to train successfully? Where else are the y observations being used in training?

Just my guess: I think there is no gradient at value=0, so the AD system gets confused:

import torch

value = torch.tensor(0., requires_grad=True)
y = torch.where(value > 0., torch.log(value), torch.log1p(value))
y.backward()
value.grad  # nan: torch.where still backpropagates through the untaken log(value) branch

Interesting. That makes sense. Do you recommend a different solution? Or is my current solution correct?

Hi @yoshy, I believe you can always guard against the bad value in the non-executed branch. For example:

import torch

value = torch.tensor(0., requires_grad=True)
safe_value = torch.where(value > 0., value, torch.tensor(1.))
y = torch.where(value > 0., torch.log(safe_value), torch.log1p(value))
y.backward()
value.grad  # y is the same as before, but AD is happy now

See this note from the TFP team.
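Applied to your GammaHurdle log_prob, the guard might look something like this (untested sketch):

def log_prob(self, value):
    value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
    if self._validate_args:
        self._validate_sample(value)

    # replace zeros with a harmless dummy so log(value) in the untaken
    # branch no longer produces inf/nan gradients
    safe_value = torch.where(value > 0, value, torch.ones_like(value))
    return torch.where(value > 0,
                       (torch.log(1 - self.theta) + self.concentration * torch.log(self.rate) +
                        (self.concentration - 1) * torch.log(safe_value) -
                        self.rate * safe_value - torch.lgamma(self.concentration)),
                       torch.log(self.theta))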

This seems to have fixed the issue! Thank you for all your help.

I’m now trying to add a validation-set loss metric computed during training.

The model definition is essentially the same as discussed throughout my posts. I’ve added an evaluate function; for now I’m keeping it simple by just calculating the MAE based on the average of the predictive samples.

def evaluate(model, criterion, val_dataloader_iter, validation_steps, device, metric_agg_fn=None):
    model.eval()  # Set model to evaluate mode

    predictive_obs = Predictive(model, guide=guide, num_samples=100, return_sites=['obs'])

    # statistics
    running_loss = 0.0

    # Iterate over all the validation data.
    for step in range(validation_steps):
        pd_batch = next(val_dataloader_iter)
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1).double()
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].to(device)
        samples_obs = predictive_obs(inputs)
        loss = np.absolute(torch.mean(samples_obs['obs'], dim=0) - labels).mean()
        running_loss += loss

    # The losses are averaged across observations for each minibatch.
    epoch_loss = running_loss / validation_steps

    # metric_agg_fn is used in distributed training to aggregate the metrics across workers
    print('Validation Loss: {:.4f}'.format(epoch_loss))
    return epoch_loss

def train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device):
    running_loss = 0.0
    # Iterate over the data for one epoch.
    for step in range(steps_per_epoch):
        pd_batch = next(train_dataloader_iter)
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1).double()
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].double()  # + .01
        labels = labels.to(device)
        loss = svi.step(inputs, labels)

        # statistics
        running_loss += loss

    epoch_loss = running_loss / steps_per_epoch

    if (epoch + 1) % 10 == 0: print('Train Loss: {:.4f}'.format(epoch_loss))
    return epoch_loss

def train_and_evaluate_SVI(svi, lr=0.001):
    criterion = Trace_ELBO()

    with data.converter_train.make_torch_dataloader(batch_size=BATCH_SIZE) as train_dataloader, \
         data.converter_val.make_torch_dataloader(batch_size=BATCH_SIZE) as val_dataloader:

        train_dataloader_iter = iter(train_dataloader)
        steps_per_epoch = len(data.converter_train) // BATCH_SIZE

        val_dataloader_iter = iter(val_dataloader)
        validation_steps = max(1, len(data.converter_val) // BATCH_SIZE)

        for epoch in range(NUM_EPOCHS):
            if (epoch + 1) % 10 == 0: print('Epoch {}/{}'.format(epoch + 1, NUM_EPOCHS))
            if (epoch + 1) % 10 == 0: print('-' * 10)

            train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)

            if (epoch + 1) % 10 == 0: val_loss = evaluate(model, criterion, val_dataloader_iter, validation_steps, device)

However, during training the evaluate function raises errors about the parameters: ValueError: The parameter rate has invalid values. The error occurs in my evaluate function, not in the training function. My questions, then: is Predictive the right function to use? And, if so, am I missing an argument that constrains the fitted model parameters for evaluation?

I think Predictive is a good choice for evaluation. Could you check why the rate parameter has invalid values? In Pyro, I think you don’t need to worry about the unconstrained values; most of that happens under the hood.
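For example, you could check the quantities that feed into the rate inside the model’s forward, just before the observation site, to see whether extreme guide samples are producing non-finite or non-positive values; a sketch of the kind of check I mean:

rate = shape / mu
# debugging check: extreme posterior samples of the weights can make
# mu overflow/underflow after .exp(), giving inf, 0 or nan rates
if not torch.all(torch.isfinite(rate) & (rate > 0)):
    print("bad rate:", rate.min().item(), rate.max().item(),
          "mu:", mu.min().item(), mu.max().item())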

Great! Thanks for sharing, this is helpful.