Skip to content

Commit c93c4dd

Browse files
bottlerfacebook-github-bot
authored andcommitted
axis_angle representation of rotations
Summary: We can represent a rotation as a vector in the axis direction, whose length is the rotation anticlockwise in radians around that axis. Reviewed By: gkioxari Differential Revision: D24306293 fbshipit-source-id: 2e0f138eda8329f6cceff600a6e5f17a00e4deb7
1 parent 005a334 commit c93c4dd

File tree

2 files changed

+131
-7
lines changed

2 files changed

+131
-7
lines changed

pytorch3d/transforms/rotation_conversions.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,101 @@ def quaternion_apply(quaternion, point):
413413
return out[..., 1:]
414414

415415

416+
def axis_angle_to_matrix(axis_angle):
417+
"""
418+
Convert rotations given as axis/angle to rotation matrices.
419+
420+
Args:
421+
axis_angle: Rotations given as a vector in axis angle form,
422+
as a tensor of shape (..., 3), where the magnitude is
423+
the angle turned anticlockwise in radians around the
424+
vector's direction.
425+
426+
Returns:
427+
Rotation matrices as tensor of shape (..., 3, 3).
428+
"""
429+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
430+
431+
432+
def matrix_to_axis_angle(matrix):
433+
"""
434+
Convert rotations given as rotation matrices to axis/angle.
435+
436+
Args:
437+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
438+
439+
Returns:
440+
Rotations given as a vector in axis angle form, as a tensor
441+
of shape (..., 3), where the magnitude is the angle
442+
turned anticlockwise in radians around the vector's
443+
direction.
444+
"""
445+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
446+
447+
448+
def axis_angle_to_quaternion(axis_angle):
449+
"""
450+
Convert rotations given as axis/angle to quaternions.
451+
452+
Args:
453+
axis_angle: Rotations given as a vector in axis angle form,
454+
as a tensor of shape (..., 3), where the magnitude is
455+
the angle turned anticlockwise in radians around the
456+
vector's direction.
457+
458+
Returns:
459+
quaternions with real part first, as tensor of shape (..., 4).
460+
"""
461+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
462+
half_angles = 0.5 * angles
463+
eps = 1e-6
464+
small_angles = angles.abs() < eps
465+
sin_half_angles_over_angles = torch.empty_like(angles)
466+
sin_half_angles_over_angles[~small_angles] = (
467+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
468+
)
469+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
470+
# so sin(x/2)/x is about 1/2 - (x*x)/48
471+
sin_half_angles_over_angles[small_angles] = (
472+
0.5 - torch.square(angles[small_angles]) / 48
473+
)
474+
quaternions = torch.cat(
475+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
476+
)
477+
return quaternions
478+
479+
480+
def quaternion_to_axis_angle(quaternions):
481+
"""
482+
Convert rotations given as quaternions to axis/angle.
483+
484+
Args:
485+
quaternions: quaternions with real part first,
486+
as tensor of shape (..., 4).
487+
488+
Returns:
489+
Rotations given as a vector in axis angle form, as a tensor
490+
of shape (..., 3), where the magnitude is the angle
491+
turned anticlockwise in radians around the vector's
492+
direction.
493+
"""
494+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
495+
half_angles = torch.atan2(norms, quaternions[..., :1])
496+
angles = 2 * half_angles
497+
eps = 1e-6
498+
small_angles = angles.abs() < eps
499+
sin_half_angles_over_angles = torch.empty_like(angles)
500+
sin_half_angles_over_angles[~small_angles] = (
501+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
502+
)
503+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
504+
# so sin(x/2)/x is about 1/2 - (x*x)/48
505+
sin_half_angles_over_angles[small_angles] = (
506+
0.5 - torch.square(angles[small_angles]) / 48
507+
)
508+
return quaternions[..., 1:] / sin_half_angles_over_angles
509+
510+
416511
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
417512
"""
418513
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix

tests/test_rotation_conversions.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@
88
import torch
99
from common_testing import TestCaseMixin
1010
from pytorch3d.transforms.rotation_conversions import (
11+
axis_angle_to_matrix,
12+
axis_angle_to_quaternion,
1113
euler_angles_to_matrix,
14+
matrix_to_axis_angle,
1215
matrix_to_euler_angles,
1316
matrix_to_quaternion,
1417
matrix_to_rotation_6d,
1518
quaternion_apply,
1619
quaternion_multiply,
20+
quaternion_to_axis_angle,
1721
quaternion_to_matrix,
1822
random_quaternions,
1923
random_rotation,
@@ -60,13 +64,13 @@ def test_from_quat(self):
6064
"""quat -> mtx -> quat"""
6165
data = random_quaternions(13, dtype=torch.float64)
6266
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
63-
self.assertTrue(torch.allclose(data, mdata))
67+
self.assertClose(data, mdata)
6468

6569
def test_to_quat(self):
6670
"""mtx -> quat -> mtx"""
6771
data = random_rotations(13, dtype=torch.float64)
6872
mdata = quaternion_to_matrix(matrix_to_quaternion(data))
69-
self.assertTrue(torch.allclose(data, mdata))
73+
self.assertClose(data, mdata)
7074

7175
def test_quat_grad_exists(self):
7276
"""Quaternion calculations are differentiable."""
@@ -107,21 +111,21 @@ def test_from_euler(self):
107111
for convention in self._tait_bryan_conventions():
108112
matrices = euler_angles_to_matrix(data, convention)
109113
mdata = matrix_to_euler_angles(matrices, convention)
110-
self.assertTrue(torch.allclose(data, mdata))
114+
self.assertClose(data, mdata)
111115

112116
data[:, 1] += half_pi
113117
for convention in self._proper_euler_conventions():
114118
matrices = euler_angles_to_matrix(data, convention)
115119
mdata = matrix_to_euler_angles(matrices, convention)
116-
self.assertTrue(torch.allclose(data, mdata))
120+
self.assertClose(data, mdata)
117121

118122
def test_to_euler(self):
119123
"""mtx -> euler -> mtx"""
120124
data = random_rotations(13, dtype=torch.float64)
121125
for convention in self._all_euler_angle_conventions():
122126
euler_angles = matrix_to_euler_angles(data, convention)
123127
mdata = euler_angles_to_matrix(euler_angles, convention)
124-
self.assertTrue(torch.allclose(data, mdata))
128+
self.assertClose(data, mdata)
125129

126130
def test_euler_grad_exists(self):
127131
"""Euler angle calculations are differentiable."""
@@ -143,7 +147,7 @@ def test_quaternion_multiplication(self):
143147
ab_matrix = torch.matmul(a_matrix, b_matrix)
144148
ab_from_matrix = matrix_to_quaternion(ab_matrix)
145149
self.assertEqual(ab.shape, ab_from_matrix.shape)
146-
self.assertTrue(torch.allclose(ab, ab_from_matrix))
150+
self.assertClose(ab, ab_from_matrix)
147151

148152
def test_matrix_to_quaternion_corner_case(self):
149153
"""Check no bad gradients from sqrt(0)."""
@@ -159,14 +163,39 @@ def test_matrix_to_quaternion_corner_case(self):
159163

160164
self.assertClose(matrix, 0.95 * torch.eye(3))
161165

166+
def test_from_axis_angle(self):
167+
"""axis_angle -> mtx -> axis_angle"""
168+
n_repetitions = 20
169+
data = torch.rand(n_repetitions, 3)
170+
matrices = axis_angle_to_matrix(data)
171+
mdata = matrix_to_axis_angle(matrices)
172+
self.assertClose(data, mdata, atol=2e-6)
173+
174+
def test_from_axis_angle_has_grad(self):
175+
n_repetitions = 20
176+
data = torch.rand(n_repetitions, 3, requires_grad=True)
177+
matrices = axis_angle_to_matrix(data)
178+
mdata = matrix_to_axis_angle(matrices)
179+
quats = axis_angle_to_quaternion(data)
180+
mdata2 = quaternion_to_axis_angle(quats)
181+
(grad,) = torch.autograd.grad(mdata.sum() + mdata2.sum(), data)
182+
self.assertTrue(torch.isfinite(grad).all())
183+
184+
def test_to_axis_angle(self):
185+
"""mtx -> axis_angle -> mtx"""
186+
data = random_rotations(13, dtype=torch.float64)
187+
euler_angles = matrix_to_axis_angle(data)
188+
mdata = axis_angle_to_matrix(euler_angles)
189+
self.assertClose(data, mdata)
190+
162191
def test_quaternion_application(self):
163192
"""Applying a quaternion is the same as applying the matrix."""
164193
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
165194
matrices = quaternion_to_matrix(quaternions)
166195
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
167196
transform1 = quaternion_apply(quaternions, points)
168197
transform2 = torch.matmul(matrices, points[..., None])[..., 0]
169-
self.assertTrue(torch.allclose(transform1, transform2))
198+
self.assertClose(transform1, transform2)
170199

171200
[p, q] = torch.autograd.grad(transform1.sum(), [points, quaternions])
172201
self.assertTrue(torch.isfinite(p).all())

0 commit comments

Comments
 (0)