I’m able to converge with a model using a MultivariateNormal likelihood. However, when I use a ProjecteNormal likelihood, I get nans quite quickly.
I am also not able to use my VS Code debgger with justMyCode false, to enter pyro source and see where exactly the nans appear.
Here is my setup:
from scipy.spatial.transform import Rotation as R
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
import kornia.geometry.conversions as conversions
from pose import sample_small_rotation
def main(option):
# data
dtype = torch.float32
n_sym = 4
n_samples = n_sym*7
R_i = torch.from_numpy(R.random(n_samples).as_matrix()).to(dtype)
np.random.seed(0)
G_gt = torch.from_numpy(R.random().as_matrix()).unsqueeze(0).to(dtype)
R_ei = torch.from_numpy(sample_small_rotation(n_samples, 10)).to(dtype)
angles_sym = np.zeros((n_samples,3))
angles_unique = [0, 360//n_sym, (360*2)//n_sym,(360*3)//n_sym]
angles_sym[:,0] = np.repeat(angles_unique, n_samples//n_sym)
R_symi = torch.from_numpy(R.from_euler('zxz', angles_sym, degrees=True).as_matrix()).to(dtype)
R_axis_gt = torch.from_numpy(R.random().as_matrix()).unsqueeze(0).to(dtype)
R_j = G_gt @ R_ei @ R_i @ R_axis_gt @ R_symi @ R_axis_gt.transpose(1,2)
data = {
'R_i': R_i,
'R_j': R_j,
'R_symi': R_symi,
}
# models and guide
def pn_model(data):
uniform_pn = dist.ProjectedNormal(torch.zeros(4, dtype=dtype))
c = 10.0
small_error_pn = dist.ProjectedNormal(c*torch.tensor([0,0,0,1.0], dtype=dtype))
q_g = pyro.sample('global_rotation', uniform_pn).to(dtype)
R_g = conversions.quaternion_to_rotation_matrix(q_g).unsqueeze(0)
q_axis = pyro.sample('axis_rotation', uniform_pn).to(dtype)
R_axis = conversions.quaternion_to_rotation_matrix(q_axis).unsqueeze(0)
R_i = data['R_i']
R_j = data['R_j']
R_symi = data['R_symi']
with pyro.plate('data', len(R_i)):
R_ei_hat = R_g.transpose(1,2) @ R_j @ R_axis @ (R_i @ R_axis @ R_symi).transpose(1,2)
q_ei_hat = conversions.rotation_matrix_to_quaternion(R_ei_hat).to(dtype)
rotation_residuals = pyro.sample('rotation_residual', small_error_pn, obs = q_ei_hat.to(dtype))
print(rotation_residuals)
def mvn_model(data):
uniform_pn = dist.ProjectedNormal(torch.zeros(4, dtype=dtype))
small_error_n = dist.MultivariateNormal(torch.eye(3).reshape(9), covariance_matrix = torch.eye(9))
q_g = pyro.sample('global_rotation', uniform_pn).to(dtype)
R_g = conversions.quaternion_to_rotation_matrix(q_g).unsqueeze(0)
q_axis = pyro.sample('axis_rotation', uniform_pn).to(dtype)
R_axis = conversions.quaternion_to_rotation_matrix(q_axis).unsqueeze(0)
R_i = data['R_i']
R_j = data['R_j']
R_symi = data['R_symi']
with pyro.plate('data', len(R_i)):
R_ei_hat = R_g.transpose(1,2) @ R_j @ R_axis @ (R_i @ R_axis @ R_symi).transpose(1,2)
pyro.sample('rotation_residual', small_error_n, obs = R_ei_hat.reshape(len(R_ei_hat),9).to(dtype))
def guide(data):
global_concentration = pyro.param("global_concentration", torch.tensor([0.5,0.13,0.32,0.73], dtype=dtype))
pyro.sample('global_rotation', dist.ProjectedNormal(global_concentration))
axis_concentration = pyro.param("axis_concentration", torch.tensor([-0.0, 0.8, -0.4, 0.4], dtype=dtype))
pyro.sample('axis_rotation', dist.ProjectedNormal(axis_concentration))
# SVI setup
adam_params = {"lr": 0.0001, "clip_norm": 10.0}
optimizer = pyro.optim.ClippedAdam(adam_params)
if option == 'mvn':
model = mvn_model
elif option == 'pn':
model = pn_model
else:
raise ValueError('option must be pn or mvn')
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
losses = []
n_steps = 100
## do gradient steps
for step in range(n_steps):
if step % 10 == 0:
print(step)
losses.append(svi.step(data))
if __name__ == '__main__':
main(option='pn')
This works: main(option='mvn')
.
However this does not work: main(option='pn')
This is the error trace (you should be able to reproduce)
(datapy) (base) [gwoollard@workergpu083 latent_information_content]$ python pose_symmetry.py
0
10
Traceback (most recent call last):
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 70, in guide
pyro.sample('global_rotation', dist.ProjectedNormal(global_concentration))
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/distribution.py", line 24, in __call__
return super().__call__(*args, **kwargs)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/projected_normal.py", line 61, in __init__
super().__init__(batch_shape, event_shape, validate_args=validate_args)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/torch/distributions/distribution.py", line 68, in __init__
raise ValueError(
ValueError: Expected parameter concentration (Tensor of shape (4,)) of distribution ProjectedNormal(concentration: torch.Size([4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([nan, nan, nan, nan], requires_grad=True)
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 98, in <module>
main(option='pn')
File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 95, in main
losses.append(svi.step(data))
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/svi.py", line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/trace_elbo.py", line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/elbo.py", line 237, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/trace_elbo.py", line 57, in _get_trace
model_trace, guide_trace = get_importance_trace(
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/infer/enum.py", line 60, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
self(*args, **kwargs)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
raise exc from e
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
ret = self.fn(*args, **kwargs)
File "/mnt/ceph/users/gwoollard/repos/latent_information_content/pose_symmetry.py", line 70, in guide
pyro.sample('global_rotation', dist.ProjectedNormal(global_concentration))
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/distribution.py", line 24, in __call__
return super().__call__(*args, **kwargs)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/pyro/distributions/projected_normal.py", line 61, in __init__
super().__init__(batch_shape, event_shape, validate_args=validate_args)
File "/mnt/home/gwoollard/software/mambaforge/envs/datapy/lib/python3.10/site-packages/torch/distributions/distribution.py", line 68, in __init__
raise ValueError(
ValueError: Expected parameter concentration (Tensor of shape (4,)) of distribution ProjectedNormal(concentration: torch.Size([4])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([nan, nan, nan, nan], requires_grad=True)
Trace Shapes:
Param Sites:
global_concentration 4
Sample Sites: