My understanding now after reading the documentation and source code some more is that the zero inflated distribution class only works for distributions that already allow for zeros. I’m now working on writing a custom distribution class.
class GammaHurdle(dist.torch_distribution.TorchDistribution):
# theta: probability of a zero
arg_constraints = {'concentration': torch.distributions.constraints.positive,
'rate': torch.distributions.constraints.positive,
'theta': torch.distributions.constraints.interval(0., 1.)}
support = torch.distributions.constraints._Interval(0, float("inf"))
has_rsample = True
def __init__(self, concentration, rate, theta, validate_args=None):
self.concentration, self.rate, self.theta = broadcast_all(concentration, rate, theta)
if isinstance(concentration, Number) and isinstance(rate, Number):
batch_shape = torch.Size()
else:
batch_shape = self.concentration.size()
super(GammaHurdle, self).__init__(batch_shape, validate_args=validate_args)
def log_prob(self, value):
value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device)
if self._validate_args:
self._validate_sample(value)
return torch.where(value > 0,
(torch.log(1 - self.theta) + self.concentration * torch.log(self.rate) +
(self.concentration - 1) * torch.log(value) -
self.rate * value - torch.lgamma(self.concentration)),
torch.log(self.theta))
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
mask = torch.rand(shape) < self.theta.expand(shape)
value = torch._standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
value[mask] = 0
value.detach().clamp_(min=torch.finfo(value.dtype).tiny) # do not record in autograd graph
return value
The new PyroModule
class:
class BayesianRegression_LogGamma_shape_zeroInf(PyroModule):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = PyroModule[nn.Linear](in_features, out_features)
self.linear.weight = PyroSample(dist.Normal(0., 2.).expand([out_features, in_features]).to_event(2))
self.linear.bias = PyroSample(dist.Normal(0., 2.).expand([out_features]).to_event(1))
def forward(self, x, y=None):
mu = self.linear(x).squeeze(-1).exp()
shape = pyro.sample("shape", dist.Gamma(.01, .01))#.exp()#self.linear_shape(x).squeeze(-1).exp()
theta = pyro.sample("theta", dist.Uniform(.000001, 1.0))#.exp()
with pyro.plate("data", x.shape[0]):
obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y)
return mu# torch.cat((mu, shape), 0)
I’m backing to getting an error: ValueError: The parameter loc has invalid values
. Again, testing the code while adding a positive amount to any zeros in y results in the model training. I am currently stuck figuring out why loc
is invalid though.
Full error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in __call__(self, *args, **kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
~\AppData\Roaming\Python\Python37\site-packages\pyro\nn\module.py in __call__(self, *args, **kwargs)
412 with self._pyro_context:
--> 413 return super().__call__(*args, **kwargs)
414
C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in forward(self, *args, **kwargs)
725
--> 726 latent = self.sample_latent(*args, **kwargs)
727 plates = self._create_plates(*args, **kwargs)
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in sample_latent(self, *args, **kwargs)
687 """
--> 688 pos_dist = self.get_posterior(*args, **kwargs)
689 return pyro.sample("_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True})
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in get_posterior(self, *args, **kwargs)
903 """
--> 904 return dist.Normal(self.loc, self.scale).to_event(1)
905
~\AppData\Roaming\Python\Python37\site-packages\pyro\distributions\distribution.py in __call__(cls, *args, **kwargs)
17 return result
---> 18 return super().__call__(*args, **kwargs)
19
C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\normal.py in __init__(self, loc, scale, validate_args)
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
51
C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\distribution.py in __init__(self, batch_shape, event_shape, validate_args)
52 if not constraint.check(getattr(self, param)).all():
---> 53 raise ValueError("The parameter {} has invalid values".format(param))
54 super(Distribution, self).__init__()
ValueError: The parameter loc has invalid values
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<timed eval> in <module>
<ipython-input-26-ca50fa2dbbbe> in train_and_evaluate_SVI(svi, lr)
46 if (epoch + 1) % 100 == 0: print('-' * 10)
47
---> 48 train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
49
50 #val_loss = evaluate(model, criterion, val_dataloader_iter, validation_steps, device)
<ipython-input-26-ca50fa2dbbbe> in train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
14
15 # Track history in training
---> 16 loss = svi.step(inputs, labels)
17
18 # statistics
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\svi.py in step(self, *args, **kwargs)
126 # get loss and compute gradients
127 with poutine.trace(param_only=True) as param_capture:
--> 128 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
129
130 params = set(site["value"].unconstrained()
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
129 loss = 0.0
130 # grab a trace from the generator
--> 131 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
132 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
133 loss += loss_particle / self.num_particles
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\elbo.py in _get_traces(self, model, guide, args, kwargs)
168 else:
169 for i in range(self.num_particles):
--> 170 yield self._get_trace(model, guide, args, kwargs)
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
56 """
57 model_trace, guide_trace = get_importance_trace(
---> 58 "flat", self.max_plate_nesting, model, guide, args, kwargs)
59 if is_validation_enabled():
60 check_if_enumerated(guide_trace)
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
42 and the model that is run against it.
43 """
---> 44 guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
45 if detach:
46 guide_trace.detach_()
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in __call__(self, *args, **kwargs)
169 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170 exc = exc.with_traceback(traceback)
--> 171 raise exc from e
172 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173 return ret
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError) as e:
167 exc_type, exc_value, traceback = sys.exc_info()
~\AppData\Roaming\Python\Python37\site-packages\pyro\nn\module.py in __call__(self, *args, **kwargs)
411 def __call__(self, *args, **kwargs):
412 with self._pyro_context:
--> 413 return super().__call__(*args, **kwargs)
414
415 def __getattr__(self, name):
C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\nn\modules\module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in forward(self, *args, **kwargs)
724 self._setup_prototype(*args, **kwargs)
725
--> 726 latent = self.sample_latent(*args, **kwargs)
727 plates = self._create_plates(*args, **kwargs)
728
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in sample_latent(self, *args, **kwargs)
686 base ``model``.
687 """
--> 688 pos_dist = self.get_posterior(*args, **kwargs)
689 return pyro.sample("_{}_latent".format(self._pyro_name), pos_dist, infer={"is_auxiliary": True})
690
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\autoguide\guides.py in get_posterior(self, *args, **kwargs)
902 Returns a diagonal Normal posterior distribution.
903 """
--> 904 return dist.Normal(self.loc, self.scale).to_event(1)
905
906 def _loc_scale(self, *args, **kwargs):
~\AppData\Roaming\Python\Python37\site-packages\pyro\distributions\distribution.py in __call__(cls, *args, **kwargs)
16 if result is not None:
17 return result
---> 18 return super().__call__(*args, **kwargs)
19
20 @property
C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\normal.py in __init__(self, loc, scale, validate_args)
48 else:
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
51
52 def expand(self, batch_shape, _instance=None):
C:\ProgramData\Anaconda3\envs\ff\lib\site-packages\torch\distributions\distribution.py in __init__(self, batch_shape, event_shape, validate_args)
51 continue # skip checking lazily-constructed args
52 if not constraint.check(getattr(self, param)).all():
---> 53 raise ValueError("The parameter {} has invalid values".format(param))
54 super(Distribution, self).__init__()
55
ValueError: The parameter loc has invalid values
Trace Shapes:
Param Sites:
AutoDiagonalNormal.loc 5
AutoDiagonalNormal.scale 5
Sample Sites: