NotImplementedError: Cannot transform _Sphere constraints

I’d like to use the ProjectedNormal distribution in an MCMC kernel, but it doesn’t work with HMC or NUTS:

# Synthetic observations: one standard-normal draw per sample, packed into
# the single dict argument that model() and mcmc.run() both receive.
n_samples = 100
data = torch.randn((n_samples,))
args_dict = dict(data=data)

def model(args_dict):
    """Toy model with a ProjectedNormal latent, used to reproduce the HMC error.

    ``x1`` is drawn from ``ProjectedNormal``, whose support is the ``_Sphere``
    constraint that HMC/NUTS fails to transform in the traceback below.

    :param args_dict: dict with key ``'data'`` holding the observations.
    :returns: the observed value of site ``'x2'``.
    """
    data = args_dict['data']
    x1 = pyro.sample('x1', dist.ProjectedNormal(torch.tensor([0.0, 1.0])))
    # Record the angle (in degrees) as a deterministic site so it appears in
    # the MCMC samples instead of being computed and thrown away.
    pyro.deterministic('angle', torch.atan2(x1[0], x1[1]).rad2deg())
    # NOTE(review): x1 has shape (2,) while data has shape (n_samples,); these
    # do not broadcast, so this likelihood will likely fail once the sphere
    # constraint is handled. The intended observation model needs confirming.
    x2 = pyro.sample('x2', dist.Normal(x1, 1), obs=data)
    return x2

from pyro.infer.mcmc import HMC, MCMC

# During setup, HMC calls biject_to(site.support) for every latent site (see
# the traceback). ProjectedNormal's _Sphere support has no registered
# bijection, which raises the NotImplementedError shown below.
kernel = HMC(model)
mcmc = MCMC(kernel, num_samples=500)
mcmc.run(args_dict)

Below is the error traceback:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/var/folders/bg/cb0cr7ls61352lhy50167r0c0000gn/T/ipykernel_5114/2608258424.py in <module>
     13 kernel = HMC(model)
     14 mcmc = MCMC(kernel, num_samples=500)
---> 15 mcmc.run(args_dict)

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    561             # requires_grad", which happens with `jit_compile` under PyTorch 1.7
    562             args = [arg.detach() if torch.is_tensor(arg) else arg for arg in args]
--> 563             for x, chain_id in self.sampler.run(*args, **kwargs):
    564                 if num_samples[chain_id] == 0:
    565                     num_samples[chain_id] += 1

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    228                 i if self.num_chains > 1 else None,
    229                 *args,
--> 230                 **kwargs
    231             ):
    232                 yield sample, i  # sample, chain_id

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    142 
    143 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 144     kernel.setup(warmup_steps, *args, **kwargs)
    145     params = kernel.initial_params
    146     save_params = getattr(kernel, "save_params", sorted(params))

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    323         self._warmup_steps = warmup_steps
    324         if self.model is not None:
--> 325             self._initialize_model_properties(args, kwargs)
    326         if self.initial_params:
    327             z = {k: v.detach() for k, v in self.initial_params.items()}

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    267             skip_jit_warnings=self._ignore_jit_warnings,
    268             init_strategy=self._init_strategy,
--> 269             initial_params=self._initial_params,
    270         )
    271         self.potential_fn = potential_fn

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/pyro/infer/mcmc/util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
    451         prototype_samples[name] = node["value"].detach()
    452         if automatic_transform_enabled:
--> 453             transforms[name] = biject_to(node["fn"].support).inv
    454 
    455     trace_prob_evaluator = TraceEinsumEvaluator(

~/miniconda2/envs/phys_aware_cryoem_202207/lib/python3.7/site-packages/torch/distributions/constraint_registry.py in __call__(self, constraint)
    141         except KeyError:
    142             raise NotImplementedError(
--> 143                 f'Cannot transform {type(constraint).__name__} constraints') from None
    144         return factory(constraint)
    145 

NotImplementedError: Cannot transform _Sphere constraints

You should be able to reparametrize the model with `ProjectedNormalReparam`, so that HMC runs in an unconstrained, higher-dimensional space instead of directly on the sphere, e.g.:

@poutine.reparam(config={"x1": ProjectedNormalReparam()})
def model(args_dict):
    # The sample-site name must match the key in the reparam config above;
    # otherwise the config entry matches no site and x1 is not reparametrized.
    x1 = pyro.sample("x1", ProjectedNormal(torch.zeros(3)))
    ...

You should also be able to just use AutoReparam:

model = ...  # placeholder: your original (un-reparametrized) model function
# Wrap the model with Pyro's AutoReparam strategy; pass reparam_model (not
# model) to the HMC/NUTS kernel.
reparam_model = AutoReparam()(model)

Let me know if you have trouble, we definitely want to make this work :smile:

1 Like