8
8
import torch
9
9
from common_testing import TestCaseMixin
10
10
from pytorch3d .transforms .rotation_conversions import (
11
+ axis_angle_to_matrix ,
12
+ axis_angle_to_quaternion ,
11
13
euler_angles_to_matrix ,
14
+ matrix_to_axis_angle ,
12
15
matrix_to_euler_angles ,
13
16
matrix_to_quaternion ,
14
17
matrix_to_rotation_6d ,
15
18
quaternion_apply ,
16
19
quaternion_multiply ,
20
+ quaternion_to_axis_angle ,
17
21
quaternion_to_matrix ,
18
22
random_quaternions ,
19
23
random_rotation ,
@@ -60,13 +64,13 @@ def test_from_quat(self):
60
64
"""quat -> mtx -> quat"""
61
65
data = random_quaternions (13 , dtype = torch .float64 )
62
66
mdata = matrix_to_quaternion (quaternion_to_matrix (data ))
63
- self .assertTrue ( torch . allclose ( data , mdata ) )
67
+ self .assertClose ( data , mdata )
64
68
65
69
def test_to_quat (self ):
66
70
"""mtx -> quat -> mtx"""
67
71
data = random_rotations (13 , dtype = torch .float64 )
68
72
mdata = quaternion_to_matrix (matrix_to_quaternion (data ))
69
- self .assertTrue ( torch . allclose ( data , mdata ) )
73
+ self .assertClose ( data , mdata )
70
74
71
75
def test_quat_grad_exists (self ):
72
76
"""Quaternion calculations are differentiable."""
@@ -107,21 +111,21 @@ def test_from_euler(self):
107
111
for convention in self ._tait_bryan_conventions ():
108
112
matrices = euler_angles_to_matrix (data , convention )
109
113
mdata = matrix_to_euler_angles (matrices , convention )
110
- self .assertTrue ( torch . allclose ( data , mdata ) )
114
+ self .assertClose ( data , mdata )
111
115
112
116
data [:, 1 ] += half_pi
113
117
for convention in self ._proper_euler_conventions ():
114
118
matrices = euler_angles_to_matrix (data , convention )
115
119
mdata = matrix_to_euler_angles (matrices , convention )
116
- self .assertTrue ( torch . allclose ( data , mdata ) )
120
+ self .assertClose ( data , mdata )
117
121
118
122
def test_to_euler (self ):
119
123
"""mtx -> euler -> mtx"""
120
124
data = random_rotations (13 , dtype = torch .float64 )
121
125
for convention in self ._all_euler_angle_conventions ():
122
126
euler_angles = matrix_to_euler_angles (data , convention )
123
127
mdata = euler_angles_to_matrix (euler_angles , convention )
124
- self .assertTrue ( torch . allclose ( data , mdata ) )
128
+ self .assertClose ( data , mdata )
125
129
126
130
def test_euler_grad_exists (self ):
127
131
"""Euler angle calculations are differentiable."""
@@ -143,7 +147,7 @@ def test_quaternion_multiplication(self):
143
147
ab_matrix = torch .matmul (a_matrix , b_matrix )
144
148
ab_from_matrix = matrix_to_quaternion (ab_matrix )
145
149
self .assertEqual (ab .shape , ab_from_matrix .shape )
146
- self .assertTrue ( torch . allclose ( ab , ab_from_matrix ) )
150
+ self .assertClose ( ab , ab_from_matrix )
147
151
148
152
def test_matrix_to_quaternion_corner_case (self ):
149
153
"""Check no bad gradients from sqrt(0)."""
@@ -159,14 +163,39 @@ def test_matrix_to_quaternion_corner_case(self):
159
163
160
164
self .assertClose (matrix , 0.95 * torch .eye (3 ))
161
165
166
+ def test_from_axis_angle (self ):
167
+ """axis_angle -> mtx -> axis_angle"""
168
+ n_repetitions = 20
169
+ data = torch .rand (n_repetitions , 3 )
170
+ matrices = axis_angle_to_matrix (data )
171
+ mdata = matrix_to_axis_angle (matrices )
172
+ self .assertClose (data , mdata , atol = 2e-6 )
173
+
174
+ def test_from_axis_angle_has_grad (self ):
175
+ n_repetitions = 20
176
+ data = torch .rand (n_repetitions , 3 , requires_grad = True )
177
+ matrices = axis_angle_to_matrix (data )
178
+ mdata = matrix_to_axis_angle (matrices )
179
+ quats = axis_angle_to_quaternion (data )
180
+ mdata2 = quaternion_to_axis_angle (quats )
181
+ (grad ,) = torch .autograd .grad (mdata .sum () + mdata2 .sum (), data )
182
+ self .assertTrue (torch .isfinite (grad ).all ())
183
+
184
+ def test_to_axis_angle (self ):
185
+ """mtx -> axis_angle -> mtx"""
186
+ data = random_rotations (13 , dtype = torch .float64 )
187
+ euler_angles = matrix_to_axis_angle (data )
188
+ mdata = axis_angle_to_matrix (euler_angles )
189
+ self .assertClose (data , mdata )
190
+
162
191
def test_quaternion_application (self ):
163
192
"""Applying a quaternion is the same as applying the matrix."""
164
193
quaternions = random_quaternions (3 , torch .float64 , requires_grad = True )
165
194
matrices = quaternion_to_matrix (quaternions )
166
195
points = torch .randn (3 , 3 , dtype = torch .float64 , requires_grad = True )
167
196
transform1 = quaternion_apply (quaternions , points )
168
197
transform2 = torch .matmul (matrices , points [..., None ])[..., 0 ]
169
- self .assertTrue ( torch . allclose ( transform1 , transform2 ) )
198
+ self .assertClose ( transform1 , transform2 )
170
199
171
200
[p , q ] = torch .autograd .grad (transform1 .sum (), [points , quaternions ])
172
201
self .assertTrue (torch .isfinite (p ).all ())
0 commit comments