Skip to content

Commit dd45123

Browse files
davnov134facebook-github-bot
authored andcommitted
Linearly extrapolated acos.
Summary: Implements a backprop-safe version of `torch.acos` that linearly extrapolates the function outside bounds. Below is a plot of the extrapolated acos for different bounds: {F611339485} Reviewed By: bottler, nikhilaravi Differential Revision: D27945714 fbshipit-source-id: fa2e2385b56d6fe534338d5192447c4a3aec540c
1 parent 88f5d79 commit dd45123

File tree

4 files changed

+246
-0
lines changed

4 files changed

+246
-0
lines changed

pytorch3d/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
from .math import acos_linear_extrapolation
34
from .rotation_conversions import (
45
axis_angle_to_matrix,
56
axis_angle_to_quaternion,

pytorch3d/transforms/math.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
import math
3+
from typing import Tuple, Union
4+
5+
import torch
6+
7+
8+
def acos_linear_extrapolation(
9+
x: torch.Tensor,
10+
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
11+
) -> torch.Tensor:
12+
"""
13+
Implements `arccos(x)` which is linearly extrapolated outside `x`'s original
14+
domain of `(-1, 1)`. This allows for stable backpropagation in case `x`
15+
is not guaranteed to be strictly within `(-1, 1)`.
16+
17+
More specifically:
18+
```
19+
if -bound <= x <= bound:
20+
acos_linear_extrapolation(x) = acos(x)
21+
elif x <= -bound: # 1st order Taylor approximation
22+
acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound))
23+
else: # x >= bound
24+
acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound)
25+
```
26+
Note that `bound` can be made more specific with setting
27+
`bound=[lower_bound, upper_bound]` as detailed below.
28+
29+
Args:
30+
x: Input `Tensor`.
31+
bound: A float constant or a float 2-tuple defining the region for the
32+
linear extrapolation of `acos`.
33+
If `bound` is a float scalar, linearly interpolates acos for
34+
`x <= -bound` or `bound <= x`.
35+
If `bound` is a 2-tuple, the first/second element of `bound`
36+
describes the lower/upper bound that defines the lower/upper
37+
extrapolation region, i.e. the region where
38+
`x <= bound[0]`/`bound[1] <= x`.
39+
Note that all elements of `bound` have to be within (-1, 1).
40+
Returns:
41+
acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`.
42+
"""
43+
44+
if isinstance(bound, float):
45+
upper_bound = bound
46+
lower_bound = -bound
47+
else:
48+
lower_bound, upper_bound = bound
49+
50+
if lower_bound > upper_bound:
51+
raise ValueError("lower bound has to be smaller or equal to upper bound.")
52+
53+
if lower_bound <= -1.0 or upper_bound >= 1.0:
54+
raise ValueError("Both lower bound and upper bound have to be within (-1, 1).")
55+
56+
# init an empty tensor and define the domain sets
57+
acos_extrap = torch.empty_like(x)
58+
x_upper = x >= upper_bound
59+
x_lower = x <= lower_bound
60+
x_mid = (~x_upper) & (~x_lower)
61+
62+
# acos calculation for upper_bound < x < lower_bound
63+
acos_extrap[x_mid] = torch.acos(x[x_mid])
64+
# the linear extrapolation for x >= upper_bound
65+
acos_extrap[x_upper] = _acos_linear_approximation(x[x_upper], upper_bound)
66+
# the linear extrapolation for x <= lower_bound
67+
acos_extrap[x_lower] = _acos_linear_approximation(x[x_lower], lower_bound)
68+
69+
return acos_extrap
70+
71+
72+
def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor:
73+
"""
74+
Calculates the 1st order Taylor expansion of `arccos(x)` around `x0`.
75+
"""
76+
return (x - x0) * _dacos_dx(x0) + math.acos(x0)
77+
78+
79+
def _dacos_dx(x: float) -> float:
80+
"""
81+
Calculates the derivative of `arccos(x)` w.r.t. `x`.
82+
"""
83+
return (-1.0) / math.sqrt(1.0 - x * x)

tests/bm_acos_linear_extrapolation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from fvcore.common.benchmark import benchmark
4+
from test_acos_linear_extrapolation import TestAcosLinearExtrapolation
5+
6+
7+
def bm_acos_linear_extrapolation() -> None:
8+
kwargs_list = [
9+
{"batch_size": 1},
10+
{"batch_size": 100},
11+
{"batch_size": 10000},
12+
{"batch_size": 1000000},
13+
]
14+
benchmark(
15+
TestAcosLinearExtrapolation.acos_linear_extrapolation,
16+
"ACOS_LINEAR_EXTRAPOLATION",
17+
kwargs_list,
18+
warmup_iters=1,
19+
)
20+
21+
22+
if __name__ == "__main__":
23+
bm_acos_linear_extrapolation()
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
4+
import unittest
5+
6+
import numpy as np
7+
import torch
8+
from common_testing import TestCaseMixin
9+
from pytorch3d.transforms import acos_linear_extrapolation
10+
11+
12+
class TestAcosLinearExtrapolation(TestCaseMixin, unittest.TestCase):
13+
def setUp(self) -> None:
14+
super().setUp()
15+
torch.manual_seed(42)
16+
np.random.seed(42)
17+
18+
@staticmethod
19+
def init_acos_boundary_values(batch_size: int = 10000):
20+
"""
21+
Initialize a tensor containing values close to the bounds of the
22+
domain of `acos`, i.e. close to -1 or 1; and random values between (-1, 1).
23+
"""
24+
device = torch.device("cuda:0")
25+
# one quarter are random values between -1 and 1
26+
x_rand = 2 * torch.rand(batch_size // 4, dtype=torch.float32, device=device) - 1
27+
x = [x_rand]
28+
for bound in [-1, 1]:
29+
for above_bound in [True, False]:
30+
for noise_std in [1e-4, 1e-2]:
31+
n_generate = (batch_size - batch_size // 4) // 8
32+
x_add = (
33+
bound
34+
+ (2 * float(above_bound) - 1)
35+
* torch.randn(
36+
n_generate, device=device, dtype=torch.float32
37+
).abs()
38+
* noise_std
39+
)
40+
x.append(x_add)
41+
x = torch.cat(x)
42+
return x
43+
44+
@staticmethod
45+
def acos_linear_extrapolation(batch_size: int):
46+
x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
47+
torch.cuda.synchronize()
48+
49+
def compute_acos():
50+
acos_linear_extrapolation(x)
51+
torch.cuda.synchronize()
52+
53+
return compute_acos
54+
55+
def _test_acos_outside_bounds(self, x, y, dydx, bound):
56+
"""
57+
Check that `acos_linear_extrapolation` yields points on a line with correct
58+
slope, and that the function is continuous around `bound`.
59+
"""
60+
bound_t = torch.tensor(bound, device=x.device, dtype=x.dtype)
61+
# fit a line: slope * x + bias = y
62+
x_1 = torch.stack([x, torch.ones_like(x)], dim=-1)
63+
solution = torch.linalg.lstsq(x_1, y[:, None]).solution
64+
slope, bias = solution.view(-1)[:2]
65+
desired_slope = (-1.0) / torch.sqrt(1.0 - bound_t ** 2)
66+
# test that the desired slope is the same as the fitted one
67+
self.assertClose(desired_slope.view(1), slope.view(1), atol=1e-2)
68+
# test that the autograd's slope is the same as the desired one
69+
self.assertClose(desired_slope.expand_as(dydx), dydx, atol=1e-2)
70+
# test that the value of the fitted line at x=bound equals
71+
# arccos(x), i.e. the function is continuous around the bound
72+
y_bound_lin = (slope * bound_t + bias).view(1)
73+
y_bound_acos = bound_t.acos().view(1)
74+
self.assertClose(y_bound_lin, y_bound_acos, atol=1e-2)
75+
76+
def _one_acos_test(self, x: torch.Tensor, lower_bound: float, upper_bound: float):
77+
"""
78+
Test that `acos_linear_extrapolation` returns correct values for
79+
`x` between/above/below `lower_bound`/`upper_bound`.
80+
"""
81+
x.requires_grad = True
82+
x.grad = None
83+
y = acos_linear_extrapolation(x, [lower_bound, upper_bound])
84+
# compute the gradient of the acos w.r.t. x
85+
y.backward(torch.ones_like(y))
86+
dacos_dx = x.grad
87+
x_lower = x <= lower_bound
88+
x_upper = x >= upper_bound
89+
x_mid = (~x_lower) & (~x_upper)
90+
# test that between bounds, the function returns plain acos
91+
self.assertClose(x[x_mid].acos(), y[x_mid])
92+
# test that outside the bounds, the function is linear with the right
93+
# slope and continuous around the bound
94+
self._test_acos_outside_bounds(
95+
x[x_upper], y[x_upper], dacos_dx[x_upper], upper_bound
96+
)
97+
self._test_acos_outside_bounds(
98+
x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound
99+
)
100+
if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound
101+
# check that passing bounds=upper_bound gives the same
102+
# resut as bounds=[lower_bound, upper_bound]
103+
y_one_bound = acos_linear_extrapolation(x, upper_bound)
104+
self.assertClose(y_one_bound, y)
105+
106+
def test_acos(self, batch_size: int = 10000):
107+
"""
108+
Tests whether the function returns correct outputs
109+
inside/outside the bounds.
110+
"""
111+
x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
112+
bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5)
113+
for lower_bound in -bounds:
114+
for upper_bound in bounds:
115+
if upper_bound < lower_bound:
116+
continue
117+
self._one_acos_test(x, float(lower_bound), float(upper_bound))
118+
119+
def test_finite_gradient(self, batch_size: int = 10000):
120+
"""
121+
Tests whether gradients stay finite close to the bounds.
122+
"""
123+
x = TestAcosLinearExtrapolation.init_acos_boundary_values(batch_size)
124+
x.requires_grad = True
125+
bounds = 1 - 10.0 ** torch.linspace(-1, -5, 5)
126+
for lower_bound in -bounds:
127+
for upper_bound in bounds:
128+
if upper_bound < lower_bound:
129+
continue
130+
x.grad = None
131+
y = acos_linear_extrapolation(
132+
x,
133+
[float(lower_bound), float(upper_bound)],
134+
)
135+
self.assertTrue(torch.isfinite(y).all())
136+
loss = y.mean()
137+
loss.backward()
138+
self.assertIsNotNone(x.grad)
139+
self.assertTrue(torch.isfinite(x.grad).all())

0 commit comments

Comments
 (0)