ValueError when data size surpasses some amount

I have a fairly simple multivariate regression model:

def model(x, y=None):
    """Multivariate linear regression with correlated Gaussian noise.

    Regression weights "β" get a standard-normal prior, and the noise
    covariance is built from per-response variances "θ" (half-Cauchy prior)
    and an LKJ-distributed Cholesky correlation factor "L".

    Args:
        x: 2-D tensor of predictors, one row per observation.
        y: Optional tensor of observed responses, one row per observation
            and ``n_outputs`` (2) columns. When ``None`` the "obs" site is
            sampled instead of conditioned on.

    Note:
        ``x`` and ``y`` should contain only finite values (no NaN).
    """
    n_features = x.shape[1]
    # Number of response columns; fixes the width of β and the size of the
    # noise covariance.
    n_outputs = 2

    β = pyro.sample("β", dist.Normal(torch.zeros(n_features, n_outputs),
                                     torch.ones(n_features, n_outputs)))
    μ = torch.matmul(x, β)

    # The noise covariance lives in response space, so θ and L are sized by
    # n_outputs rather than by the number of predictors; sizing them by
    # x.shape[1] only works when x happens to have exactly n_outputs columns.
    θ = pyro.sample("θ", dist.HalfCauchy(torch.ones(n_outputs)))
    L = pyro.sample("L", dist.LKJCorrCholesky(n_outputs, torch.tensor(1.)))

    # θ is treated as a vector of variances: scale the Cholesky correlation
    # factor by the standard deviations sqrt(θ).
    L_Ω = torch.mm(torch.diag(θ.sqrt()), L)

    with pyro.plate("data", x.shape[0]):
        pyro.sample("obs", dist.MultivariateNormal(loc=μ, scale_tril=L_Ω), obs=y)

# NUTS kernel over `model`, with JIT compilation disabled and step_size=1e-5.
nuts_kernel = NUTS(model, jit_compile=False, step_size=1e-5)
# Single chain: 100 warmup steps followed by 500 samples (600 iterations in
# total), with the progress bar left enabled.
mcmc = MCMC(nuts_kernel, num_samples=500, 
     warmup_steps=100, num_chains=1, disable_progbar=False)
# x_train and y_train are forwarded to `model` as its (x, y) arguments.
mcmc.run(x_train,y_train)

When I run this on a few hundred samples or less it seems to be ok (although slow). However, when I try to run on the full dataset (1380 samples), I get the following error immediately:

Warmup:   0%|          | 0/600 [00:00, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-86-7e30656d1911> in <module>
      3 mcmc = MCMC(nuts_kernel, num_samples=500, 
      4      warmup_steps=100, num_chains=1, disable_progbar=False)
----> 5 mcmc.run(x_train,y_train)

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    380             # requires_grad", which happens with `jit_compile` under PyTorch 1.7
    381             args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 382             for x, chain_id in self.sampler.run(*args, **kwargs):
    383                 if num_samples[chain_id] == 0:
    384                     num_samples[chain_id] += 1

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    164             for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
    165                                        i if self.num_chains > 1 else None,
--> 166                                        *args, **kwargs):
    167                 yield sample, i  # sample, chain_id
    168             self.kernel.cleanup()

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    106 
    107 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 108     kernel.setup(warmup_steps, *args, **kwargs)
    109     params = kernel.initial_params
    110     # yield structure (key, value.shape) of params

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    301         self._warmup_steps = warmup_steps
    302         if self.model is not None:
--> 303             self._initialize_model_properties(args, kwargs)
    304         if self.initial_params:
    305             z = {k: v.detach() for k, v in self.initial_params.items()}

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    253             skip_jit_warnings=self._ignore_jit_warnings,
    254             init_strategy=self._init_strategy,
--> 255             initial_params=self._initial_params,
    256         )
    257         self.potential_fn = potential_fn

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
    419                                                     pe_maker.get_potential_fn(), prototype_params,
    420                                                     num_chains=num_chains, init_strategy=init_strategy,
--> 421                                                     trace=model_trace)
    422     potential_fn = pe_maker.get_potential_fn(jit_compile, skip_jit_warnings, jit_options)
    423     return initial_params, potential_fn, transforms, model_trace

~/pyro/multivariate/.pyro/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in _find_valid_initial_params(model, model_args, model_kwargs, transforms, potential_fn, prototype_params, max_tries_initial_params, num_chains, init_strategy, trace)
    330                     return {k: torch.stack(v) for k, v in params_per_chain.items()}
    331         trace = None
--> 332     raise ValueError("Model specification seems incorrect - cannot find valid initial params.")
    333 
    334 

ValueError: Model specification seems incorrect - cannot find valid initial params.
Sample:  18%|█▊        | 108/600 [00:21,  4.57it/s, step size=3.86e-03, acc. prob=0.994]

@thecity2 To debug, you can add some print statements to your model, and see if:

  • the parameters take on invalid values
  • the data contain invalid values
  • after getting the parameters/values, try a separate gradient check:
# Template: replace `...` with the value whose log-probability is being checked.
value = torch.tensor(..., requires_grad=True)
# `dist.Foo(params)` is a stand-in for the distribution and parameters used at
# the sample site under inspection.
dist.Foo(params).log_prob(value).backward()
# Gradient of the log-probability with respect to `value`.
print(value.grad)

to see whether the gradient takes on invalid values.

Thank you @fehiepsi. In the end it turned out there was a lone NaN hiding among my data. Once I removed that, I had no issues.