I’m trying to run Predictive
with parallel=True
on a neural network (again the problem I’ve posted about before). I think I’m close to having the shapes right for parallel inference but I’m stuck.
Based on searching the forum, I’m using my own nn.Linear
class. This is designed to handle shapes properly.
class myLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True, _print=False):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.bias = bias
self._print = _print
self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
self.bias = torch.nn.Parameter(torch.randn(out_features))
def forward(self, input):
x, y = input.shape
if y != self.in_features:
sys.exit(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
output = input @ self.weight.squeeze().t() + self.bias.squeeze()
return output
I get an error with Predictive
.
predictive_obs = Predictive(model, guide=guide, num_samples=int(12), return_sites = ['obs'], parallel = True)
with converter_val.make_torch_dataloader(batch_size=5) as val_dataloader:
val_dataloader_iter = iter(val_dataloader)
validation_steps = 1
# Iterate over all the validation data.
for step in range(validation_steps):
pd_batch = next(val_dataloader_iter)
pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)#.double()
inputs = pd_batch['features'].to(device)
labels = pd_batch[y_name].to(device)
samples_obs = predictive_obs(inputs, subsample=False)
ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5 vs 12
For a batch size of 5 with 12 samples, the shape I get for my parameters such as mu
is [5, 12] = [batch size, samples]
. Is the output of my linear module correct? And the error is in plate somewhere?
Full error output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
<command-3701536137140920> in forward(self, x, y, subsample)
118 with pyro.plate("data", x.shape[0], device = device, dim=-1):
--> 119 obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y) #.to_event(1)
120 #print(obs.shape)
/databricks/python/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
162 # apply the stack and return its return value
--> 163 apply_stack(msg)
164 return msg["value"]
/databricks/python/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
212
--> 213 frame._process_message(msg)
214
/databricks/python/lib/python3.8/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
18 super()._process_message(msg)
---> 19 return BroadcastMessenger._pyro_sample(msg)
20
/usr/lib/python3.8/contextlib.py in inner(*args, **kwds)
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
76 return inner
/databricks/python/lib/python3.8/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
64 ):
---> 65 raise ValueError(
66 "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5 vs 12
The above exception was the direct cause of the following exception:
ValueError Traceback (most recent call last)
<command-3701536137140873> in <module>
11 inputs = pd_batch['features'].to(device)
12 labels = pd_batch[y_name].to(device)
---> 13 samples_obs = predictive_obs(inputs, subsample=False)
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
271 model_kwargs=kwargs,
272 )
--> 273 return _predictive(
274 self.model,
275 posterior_samples,
/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
135 )
136
--> 137 trace = poutine.trace(
138 poutine.condition(vectorize(model), reshaped_samples)
139 ).get_trace(*model_args, **model_kwargs)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
178 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
179 exc = exc.with_traceback(traceback)
--> 180 raise exc from e
181 self.msngr.trace.add_node(
182 "_RETURN", name="_RETURN", type="return", value=ret
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
424 def __call__(self, *args, **kwargs):
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
428 def __getattr__(self, name):
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
<command-3701536137140920> in forward(self, x, y, subsample)
117 else:
118 with pyro.plate("data", x.shape[0], device = device, dim=-1):
--> 119 obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y) #.to_event(1)
120 #print(obs.shape)
121 # if SUBSAMPLE_SVI_N is None:
/databricks/python/lib/python3.8/site-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
161 }
162 # apply the stack and return its return value
--> 163 apply_stack(msg)
164 return msg["value"]
165
/databricks/python/lib/python3.8/site-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
211 pointer = pointer + 1
212
--> 213 frame._process_message(msg)
214
215 if msg["stop"]:
/databricks/python/lib/python3.8/site-packages/pyro/poutine/plate_messenger.py in _process_message(self, msg)
17 def _process_message(self, msg):
18 super()._process_message(msg)
---> 19 return BroadcastMessenger._pyro_sample(msg)
20
21 def __enter__(self):
/usr/lib/python3.8/contextlib.py in inner(*args, **kwds)
73 def inner(*args, **kwds):
74 with self._recreate_cm():
---> 75 return func(*args, **kwds)
76 return inner
77
/databricks/python/lib/python3.8/site-packages/pyro/poutine/broadcast_messenger.py in _pyro_sample(msg)
63 and target_batch_shape[f.dim] != f.size
64 ):
---> 65 raise ValueError(
66 "Shape mismatch inside plate('{}') at site {} dim {}, {} vs {}".format(
67 f.name,
ValueError: Shape mismatch inside plate('data') at site obs dim -1, 5 vs 12
Trace Shapes:
Param Sites:
mu.mu_fc0.weight 8 2
mu.mu_fc0.bias 8
mu.mu_fc1.weight 8 8
mu.mu_fc1.bias 8
shape.sh_fc0.weight 8 2
shape.sh_fc0.bias 8
shape.sh_fc1.weight 8 8
shape.sh_fc1.bias 8
theta.th_fc0.weight 8 2
theta.th_fc0.bias 8
theta.th_fc1.weight 8 8
theta.th_fc1.bias 8
Sample Sites:
mu.mu_fc2.weight dist 12 1 | 1 8
value 12 | 1 8
mu.mu_fc2.bias dist 12 1 | 1
value 12 | 1
shape.sh_fc2.weight dist 12 1 | 1 8
value 12 | 1 8
shape.sh_fc2.bias dist 12 1 | 1
value 12 | 1
theta.th_fc2.weight dist 12 1 | 1 8
value 12 | 1 8
theta.th_fc2.bias dist 12 1 | 1
value 12 | 1
data dist |
value 5 |
Here is my model class (I’ve posted it in previous questions as well). I’ll edit things down for brevity.
class BayesNN_Gamma_muShapeTheta(PyroModule):
#
@name_count
def __init__(self, in_features, mu_l, sh_l, th_l, out_features = 1,unif_bound = 3.):
super().__init__()
# parameter names list
self.parameter_names = []
layers = []
for i in range(len(mu_l)-1):
layers.append(('mu_fc' + str(i), myLinear(mu_l[i], mu_l[i+1], _print=True)))
if i != (len(mu_l)-2): layers.append(('mu_ReLU' + str(i), nn.ReLU()))
mu = OrderedDict(layers)
self.mu = nn.Sequential(mu)
for name, param in self.mu.named_parameters():
self.parameter_names.append(name)
pyro.nn.module.to_pyro_module_(self.mu)
# following step sets only the final layer to Bayesian
# the lower layers remain constants as in a regular neural net
for m in self.mu.modules():
if m._pyro_name == 'mu.mu_fc' + str(len(mu_l)-2):
for name, value in list(m.named_parameters(recurse=False)):
setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))
# shape and theta are basically the same
...
def forward(self, x, y=None, subsample = True):
mu = self.mu(x).exp().clamp(min = .000001, max = 1e6)
shape = self.shape(x).exp().clamp(min = 0.000001, max = 1e2)
theta = self.theta(x)
with pyro.plate("data", x.shape[0], device = device, dim=-1):
obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y)
return torch.cat((mu, shape, theta), 0)