I have a fairly simple multivariate regression model:
def model(x, y=None):
    """Multivariate linear regression with an LKJ prior on the noise correlation.

    The regression weights ``β`` get a standard-normal prior, the mean is
    ``x @ β``, and the observation noise is a multivariate normal whose
    Cholesky factor is ``diag(sqrt(θ)) @ L``; ``θ`` therefore plays the
    role of a per-response variance and ``L`` is the Cholesky factor of the
    correlation matrix.

    Args:
        x: Design matrix, indexed as ``(num_obs, num_features)``.
        y: Optional responses used to condition the ``"obs"`` site. When
            given, its last dimension sets the number of response
            dimensions; when ``None`` the model falls back to 2 responses,
            matching the size hard-coded in the original ``β`` prior.
    """
    num_obs, num_features = x.shape[0], x.shape[1]
    # The noise covariance lives in response space, so θ and L must be sized
    # by the number of responses (columns of β), not by the number of
    # predictors; the two only coincide when x happens to have 2 columns.
    num_outputs = 2 if y is None else y.shape[-1]

    # Dimensions outside any plate are declared as event dims so that each
    # site is a single (matrix- or vector-valued) random variable.
    β = pyro.sample(
        "β",
        dist.Normal(
            torch.zeros(num_features, num_outputs),
            torch.ones(num_features, num_outputs),
        ).to_event(2),
    )
    μ = torch.matmul(x, β)

    θ = pyro.sample("θ", dist.HalfCauchy(torch.ones(num_outputs)).to_event(1))
    L = pyro.sample("L", dist.LKJCorrCholesky(num_outputs, torch.tensor(1.)))
    # Scale the rows of the correlation Cholesky factor by the per-response
    # standard deviations sqrt(θ) to obtain the covariance Cholesky factor.
    L_Ω = torch.mm(torch.diag(θ.sqrt()), L)

    with pyro.plate("data", num_obs):
        pyro.sample("obs", dist.MultivariateNormal(loc=μ, scale_tril=L_Ω), obs=y)
# Build a NUTS kernel for `model` (no JIT compilation, step size 1e-5).
nuts_kernel = NUTS(model, step_size=1e-5, jit_compile=False)

# Single chain: 100 warmup steps followed by 500 retained samples, with the
# progress bar left enabled.
mcmc = MCMC(
    nuts_kernel,
    warmup_steps=100,
    num_samples=500,
    num_chains=1,
    disable_progbar=False,
)

# Run inference; the arguments are forwarded to `model` as (x, y).
mcmc.run(x_train, y_train)
When I run this on a few hundred samples or fewer, it seems to work (although slowly). However, when I try to run it on the full dataset (1380 samples), I get the following error immediately:
Warmup: 0%| | 0/600 [00:00, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-86-7e30656d1911> in <module>
3 mcmc = MCMC(nuts_kernel, num_samples=500,
4 warmup_steps=100, num_chains=1, disable_progbar=False)
----> 5 mcmc.run(x_train,y_train)
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
380 # requires_grad", which happens with `jit_compile` under PyTorch 1.7
381 args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 382 for x, chain_id in self.sampler.run(*args, **kwargs):
383 if num_samples[chain_id] == 0:
384 num_samples[chain_id] += 1
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
164 for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
165 i if self.num_chains > 1 else None,
--> 166 *args, **kwargs):
167 yield sample, i # sample, chain_id
168 self.kernel.cleanup()
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
106
107 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 108 kernel.setup(warmup_steps, *args, **kwargs)
109 params = kernel.initial_params
110 # yield structure (key, value.shape) of params
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
301 self._warmup_steps = warmup_steps
302 if self.model is not None:
--> 303 self._initialize_model_properties(args, kwargs)
304 if self.initial_params:
305 z = {k: v.detach() for k, v in self.initial_params.items()}
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
253 skip_jit_warnings=self._ignore_jit_warnings,
254 init_strategy=self._init_strategy,
--> 255 initial_params=self._initial_params,
256 )
257 self.potential_fn = potential_fn
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
419 pe_maker.get_potential_fn(), prototype_params,
420 num_chains=num_chains, init_strategy=init_strategy,
--> 421 trace=model_trace)
422 potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
423 return initial_params, potential_fn, transforms, model_trace
~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params, max_tries_initial_params, num_chains, init_strategy, trace)
330 return {k: torch.stack(v) for k, v in params_per_chain.items()}
331 trace = None
--> 332 raise ValueError("Model specification seems incorrect - cannot find valid initial params.")
333
334
ValueError: Model specification seems incorrect - cannot find valid initial params.
Sample: 18%|█▊ | 108/600 [00:21, 4.57it/s, step size=3.86e-03, acc. prob=0.994]