Skip to content

Commit 4d52f9f

Browse files
bottlerfacebook-github-bot
authored andcommitted
matrix_to_quaternion corner case
Summary: Issue #119. The function `sqrt(max(x, 0))` is not convex and has infinite gradient at 0, but 0 is a subgradient at 0. Here we implement it in such a way as to give 0 as the gradient. Reviewed By: gkioxari Differential Revision: D24306294 fbshipit-source-id: 48d136faca083babad4d64970be7ea522dbe9e09
1 parent 2d39723 commit 4d52f9f

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

pytorch3d/transforms/rotation_conversions.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ def _copysign(a, b):
8282
return torch.where(signs_differ, -a, a)
8383

8484

85+
def _sqrt_positive_part(x):
86+
"""
87+
Returns torch.sqrt(torch.max(0, x))
88+
but with a zero subgradient where x is 0.
89+
"""
90+
ret = torch.zeros_like(x)
91+
positive_mask = x > 0
92+
ret[positive_mask] = torch.sqrt(x[positive_mask])
93+
return ret
94+
95+
8596
def matrix_to_quaternion(matrix):
8697
"""
8798
Convert rotations given as rotation matrices to quaternions.
@@ -94,14 +105,13 @@ def matrix_to_quaternion(matrix):
94105
"""
95106
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
96107
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
97-
zero = matrix.new_zeros((1,))
98108
m00 = matrix[..., 0, 0]
99109
m11 = matrix[..., 1, 1]
100110
m22 = matrix[..., 2, 2]
101-
o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
102-
x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
103-
y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
104-
z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
111+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
112+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
113+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
114+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
105115
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
106116
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
107117
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])

tests/test_rotation_conversions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,20 @@ def test_quaternion_multiplication(self):
145145
self.assertEqual(ab.shape, ab_from_matrix.shape)
146146
self.assertTrue(torch.allclose(ab, ab_from_matrix))
147147

148+
def test_matrix_to_quaternion_corner_case(self):
149+
"""Check no bad gradients from sqrt(0)."""
150+
matrix = torch.eye(3, requires_grad=True)
151+
target = torch.Tensor([0.984808, 0, 0.174, 0])
152+
153+
optimizer = torch.optim.Adam([matrix], lr=0.05)
154+
optimizer.zero_grad()
155+
q = matrix_to_quaternion(matrix)
156+
loss = torch.sum((q - target) ** 2)
157+
loss.backward()
158+
optimizer.step()
159+
160+
self.assertClose(matrix, 0.95 * torch.eye(3))
161+
148162
def test_quaternion_application(self):
149163
"""Applying a quaternion is the same as applying the matrix."""
150164
quaternions = random_quaternions(3, torch.float64, requires_grad=True)

0 commit comments

Comments
 (0)