I’m a new Pyro user and want to use Pyro to sample from a model I have written in PyTorch. The model is a function from R^N to R^2: it takes a vector and outputs the signed log-determinant of a matrix, i.e. the sign of the determinant and the log of its absolute value. I can turn this into an (unnormalized) probability distribution by multiplying the log-abs value by 2, which is equivalent to squaring the determinant and then taking its log!
That is, my model returns:

```python
sign, logabsdet = net(X)
pdf = logabsdet.mul(2)  # log of the (unnormalized) density; this is equal to:
# det = sign * torch.exp(logabsdet)
# pdf = torch.log(det**2)
```
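For reference, the identity log(det²) = 2·log|det| can be checked numerically. This is a minimal sketch that uses `torch.linalg.slogdet` on a random matrix `A` as a hypothetical stand-in for whatever matrix `net(X)` builds (the matrix and its size are assumptions, not part of the original model). It also shows the sign flip HMC works with: the potential energy is the *negative* log-density, i.e. `-2 * logabsdet`.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the matrix that net(X) builds internally
A = torch.randn(4, 4, dtype=torch.float64)

# slogdet returns the same (sign, log|det|) pair that net(X) returns
sign, logabsdet = torch.linalg.slogdet(A)

# Stable route: stay in log space, log(det^2) = 2 * log|det|
log_density = 2.0 * logabsdet

# Naive route: rebuild det, square it, take the log
det = sign * torch.exp(logabsdet)
log_density_naive = torch.log(det ** 2)

# HMC's potential energy is the negative (unnormalized) log-density
potential_energy = -log_density

print(log_density.item(), log_density_naive.item(), potential_energy.item())
```

Working with `2 * logabsdet` directly avoids forming `det` at all, which can overflow or underflow for large matrices.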
I asked a question a few months ago (Can the MCMC modules in Pyro be used sample inputs of a PyTorch model?) about how to sample the inputs of a model directly, i.e. how to generate the input X in the example above via an MCMC method. I’ve followed the answers in that topic, but I get the following error:

```
ValueError: Must provide valid initial parameters to begin sampling when using `potential_fn` in HMC/NUTS kernel.
```
```python
net = mynetwork(args...)
net = net.to(device)
hmc_kernel = pyro.infer.mcmc.HMC(potential_fn=net, step_size=1, target_accept_prob=0.5)
mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=1, warmup_steps=10, num_chains=4096)  # crashes here
mcmc.run()
samples = mcmc.get_samples()
print(samples)
```
Here is the full error message:

```
Traceback (most recent call last):
  File "run_pyro.py", line 53, in <module>
    mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=1, warmup_steps=10, num_chains=4096)
  File "~/anaconda3/lib/python3.8/site-packages/pyro/infer/mcmc/api.py", line 476, in __init__
    self._validate_kernel(initial_params)
  File "~/anaconda3/lib/python3.8/site-packages/pyro/infer/mcmc/api.py", line 389, in _validate_kernel
    raise ValueError(
ValueError: Must provide valid initial parameters to begin sampling when using `potential_fn` in HMC/NUTS kernel.
```
I’ve tried passing a Normal distribution as `initial_params`, but that fails too, so I suspect I’m making a novice mistake that is easy to diagnose!
```python
mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=1, warmup_steps=10, num_chains=4096,
                       initial_params=pyro.distributions.Normal(loc=0, scale=1))  # fails too
```
Any help would be appreciated! Thank you in advance!