I’d like to use the ProjectedNormal distribution in a MCMC kernel. This doesn’t work with HMC or NUTS:
n_samples = 100
data = torch.randn(n_samples)
args_dict = {'data':data}
def model(args_dict):
data = args_dict['data']
x1 = pyro.sample('x1', dist.ProjectedNormal(tensor([0.0,1.0])))
angle = torch.atan2(x1[0],x1[1]).rad2deg()
x2 = pyro.sample('x2', dist.Normal(x1,1), obs=data)
return x2
from pyro.infer.mcmc import HMC
kernel = HMC(model)
mcmc = MCMC(kernel, num_samples=500)
mcmc.run(args_dict)
Below is the error trace back
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
/var/folders/bg/cb0cr7ls61352lhy50167r0c0000gn/T/ipykernel_5114/2608258424.py in <module>
13 kernel = HMC(model)
14 mcmc = MCMC(kernel, num_samples=500)
---> 15 mcmc.run(args_dict)
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
561 # requires_grad", which happens with `jit_compile` under PyTorch 1.7
562 args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563 for x, chain_id in self.sampler.run(*args, **kwargs):
564 if num_samples[chain_id] == 0:
565 num_samples[chain_id] += 1
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
228 i if self.num_chains > 1 else None,
229 *args,
--> 230 **kwargs
231 ):
232 yield sample, i # sample, chain_id
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
142
143 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 144 kernel.setup(warmup_steps, *args, **kwargs)
145 params = kernel.initial_params
146 save_params = getattr(kernel, "save_params", sorted(params))
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
323 self._warmup_steps = warmup_steps
324 if self.model is not None:
--> 325 self._initialize_model_properties(args, kwargs)
326 if self.initial_params:
327 z = {k: v.detach() for k, v in self.initial_params.items()}
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
267 skip_jit_warnings=self._ignore_jit_warnings,
268 init_strategy=self._init_strategy,
--> 269 initial_params=self._initial_params,
270 )
271 self.potential_fn = potential_fn
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
451 prototype_samples[name] = node["value"].detach()
452 if automatic_transform_enabled:
--> 453 transforms[name] = biject_to(node["fn"].support).inv
454
455 trace_prob_evaluator = TraceEinsumEvaluator(
~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/torch/distributions/constraint_registry.py in __call__(self, constraint)
141 except KeyError:
142 raise NotImplementedError(
--> 143 f'Cannot transform {type(constraint).__name__} constraints') from None
144 return factory(constraint)
145
NotImplementedError: Cannot transform _Sphere constraints