Nans from svi with ProjectedNormal likelihood

I’m able to converge with a model using a MultivariateNormal likelihood. However, when I use a ProjecteNormal likelihood, I get nans quite quickly.

I am also not able to use my VS Code debgger with justMyCode false, to enter pyro source and see where exactly the nans appear.

Here is my setup:

from scipy.spatial.transform import Rotation as R
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
import kornia.geometry.conversions as conversions
from pose import  sample_small_rotation


def main(option):

    # data
    dtype = torch.float32
    n_sym = 4
    n_samples = n_sym*7
    R_i = torch.from_numpy(R.random(n_samples).as_matrix()).to(dtype)
    np.random.seed(0)
    G_gt = torch.from_numpy(R.random().as_matrix()).unsqueeze(0).to(dtype)
    R_ei = torch.from_numpy(sample_small_rotation(n_samples, 10)).to(dtype)
    angles_sym = np.zeros((n_samples,3))
    angles_unique = [0, 360//n_sym, (360*2)//n_sym,(360*3)//n_sym]
    angles_sym[:,0] = np.repeat(angles_unique, n_samples//n_sym)
    R_symi = torch.from_numpy(R.from_euler('zxz', angles_sym, degrees=True).as_matrix()).to(dtype)
    R_axis_gt = torch.from_numpy(R.random().as_matrix()).unsqueeze(0).to(dtype)
    R_j = G_gt @ R_ei @ R_i @ R_axis_gt @ R_symi @ R_axis_gt.transpose(1,2)
    data = {
        'R_i': R_i,
        'R_j': R_j,
        'R_symi': R_symi,
    }


    # models and guide
    def pn_model(data):
        uniform_pn = dist.ProjectedNormal(torch.zeros(4, dtype=dtype))
        c = 10.0
        small_error_pn = dist.ProjectedNormal(c*torch.tensor([0,0,0,1.0], dtype=dtype))
        q_g = pyro.sample('global_rotation', uniform_pn).to(dtype)
        R_g = conversions.quaternion_to_rotation_matrix(q_g).unsqueeze(0)
        q_axis = pyro.sample('axis_rotation', uniform_pn).to(dtype)
        R_axis = conversions.quaternion_to_rotation_matrix(q_axis).unsqueeze(0)
        R_i = data['R_i']
        R_j = data['R_j']
        R_symi = data['R_symi']
        
        with pyro.plate('data', len(R_i)):
            R_ei_hat = R_g.transpose(1,2) @ R_j @ R_axis @ (R_i @ R_axis @ R_symi).transpose(1,2)
            q_ei_hat = conversions.rotation_matrix_to_quaternion(R_ei_hat).to(dtype)
            rotation_residuals = pyro.sample('rotation_residual', small_error_pn, obs = q_ei_hat.to(dtype))
            print(rotation_residuals)

    def mvn_model(data):
        uniform_pn = dist.ProjectedNormal(torch.zeros(4, dtype=dtype))
        small_error_n = dist.MultivariateNormal(torch.eye(3).reshape(9), covariance_matrix = torch.eye(9))
        q_g = pyro.sample('global_rotation', uniform_pn).to(dtype)
        R_g = conversions.quaternion_to_rotation_matrix(q_g).unsqueeze(0)
        q_axis = pyro.sample('axis_rotation', uniform_pn).to(dtype)
        R_axis = conversions.quaternion_to_rotation_matrix(q_axis).unsqueeze(0)
        R_i = data['R_i']
        R_j = data['R_j']
        R_symi = data['R_symi']
        
        with pyro.plate('data', len(R_i)):
            R_ei_hat = R_g.transpose(1,2) @ R_j @ R_axis @ (R_i @ R_axis @ R_symi).transpose(1,2)
            pyro.sample('rotation_residual', small_error_n, obs = R_ei_hat.reshape(len(R_ei_hat),9).to(dtype))
        
        
    def guide(data):
        global_concentration = pyro.param("global_concentration", torch.tensor([0.5,0.13,0.32,0.73], dtype=dtype))
        pyro.sample('global_rotation', dist.ProjectedNormal(global_concentration))
        
        axis_concentration = pyro.param("axis_concentration", torch.tensor([-0.0,  0.8, -0.4,  0.4], dtype=dtype))
        pyro.sample('axis_rotation', dist.ProjectedNormal(axis_concentration))
        

    # SVI setup
    adam_params = {"lr": 0.0001, "clip_norm": 10.0}
    optimizer = pyro.optim.ClippedAdam(adam_params)

    if option == 'mvn':
        model = mvn_model
    elif option == 'pn':
        model = pn_model
    else:
        raise ValueError('option must be pn or mvn')
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    losses = []

    n_steps = 100
    ## do gradient steps
    for step in range(n_steps):
        if step % 10 == 0:
            print(step)
        losses.append(svi.step(data))

if __name__ == '__main__':
    main(option='pn')

This works: main(option='mvn').

However this does not work: main(option='pn')

This is the error trace (you should be able to reproduce)

(datapy) (base) [gwoollard@workergpu083 latent_information_content]$ python pose_symmetry.py 
0
10
Traceback (most recent call last):
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 70, in guide
    pyro.sample('global_rotation', dist.ProjectedNormal(global_concentration))
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/distribution.py", line 24, in __call__
    return super().__call__(*args, **kwargs)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/projected_normal.py", line 61, in __init__
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter concentration (Tensor of shape (4,)) of distribution ProjectedNormal(concentration: torch.Size([4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([nan, nan, nan, nan], requires_grad=True)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 98, in <module>
    main(option='pn')
  File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 95, in main
    losses.append(svi.step(data))
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/svi.py", line 145, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/elbo.py", line 237, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/trace_elbo.py", line 57, in _get_trace
    model_trace, guide_trace = get_importance_trace(
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/enum.py", line 60, in get_importance_trace
    guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 70, in guide
    pyro.sample('global_rotation', dist.ProjectedNormal(global_concentration))
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/distribution.py", line 24, in __call__
    return super().__call__(*args, **kwargs)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/projected_normal.py", line 61, in __init__
    super().__init__(batch_shape, event_shape, validate_args=validate_args)
  File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter concentration (Tensor of shape (4,)) of distribution ProjectedNormal(concentration: torch.Size([4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([nan, nan, nan, nan], requires_grad=True)
       Trace Shapes:  
        Param Sites:  
global_concentration 4
       Sample Sites: