Hi All,
I’m new to Pyro and want to use it to sample from a model I have written in PyTorch. The model is a function from R^N to R^2: it takes a vector and outputs the sign and log-absolute-determinant of a matrix. I can turn this into an (unnormalized) log-probability density by taking the log-abs value and multiplying it by 2, which is equivalent to squaring the function and then taking its log.
That is, my model returns
sign, logabsdet = net(X)
pdf = logabsdet.mul(2)  # log of the (unnormalized) density
# this is equal to
det = sign * torch.exp(logabsdet)
pdf = torch.log(det**2)
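As a quick sanity check of that equivalence, here is a minimal sketch. The random matrix `A` and the use of `torch.linalg.slogdet` are assumptions standing in for whatever my network computes internally; they are not the network itself.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the matrix built inside the network.
A = torch.randn(4, 4, dtype=torch.float64)

# Same (sign, logabsdet) pair that the model is described as returning.
sign, logabsdet = torch.linalg.slogdet(A)

# Route 1: double the log-abs-determinant (numerically stable).
log_pdf_stable = logabsdet.mul(2)

# Route 2: rebuild the determinant, square it, then take the log.
det = sign * torch.exp(logabsdet)
log_pdf_naive = torch.log(det**2)

# Both routes give the same unnormalized log-density up to rounding.
print(torch.allclose(log_pdf_stable, log_pdf_naive))
```

The first route avoids exponentiating, so it should stay finite even when the determinant itself would overflow or underflow.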
I asked a question a few months ago (Can the MCMC modules in Pyro be used sample inputs of a PyTorch model?) about how to sample inputs from a model directly, i.e., how to generate the input X in the example above via some MCMC method. I’ve followed the answers in that topic, but I get the following error:
"ValueError: Must provide valid initial parameters to begin sampling when using potential_fn in HMC/NUTS kernel."
net = mynetwork(args...)
net = net.to(device)
hmc_kernel = pyro.infer.mcmc.HMC(potential_fn=net, step_size=1, target_accept_prob=0.5)
mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=1, warmup_steps=10, num_chains=4096)  # crashes here
mcmc.run()
samples = mcmc.get_samples()
print(samples)
The full error message is:
Traceback (most recent call last):
  File "run_pyro.py", line 53, in <module>
    mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=1, warmup_steps=10, num_chains=4096)
  File "~/anaconda3/lib/python3.8/site-packages/pyro/infer/mcmc/api.py", line 476, in __init__
    self._validate_kernel(initial_params)
  File "~/anaconda3/lib/python3.8/site-packages/pyro/infer/mcmc/api.py", line 389, in _validate_kernel
    raise ValueError(
ValueError: Must provide valid initial parameters to begin sampling when using `potential_fn` in HMC/NUTS kernel.
I’ve tried passing a normal distribution as the initial params, but that doesn’t work either. I suspect I’m making a novice mistake that is easy to diagnose.
mcmc = pyro.infer.MCMC(hmc_kernel, num_samples=1, warmup_steps=10, num_chains=4096, initial_params=pyro.distributions.Normal(loc=0, scale=1))  # fails too
Any help would be appreciated! Thank you in advance!