
Commit 316b777

davnov134 authored and facebook-github-bot committed
Camera alignment
Summary: adds a `corresponding_cameras_alignment` function that estimates a similarity transformation between two sets of cameras. The function is essential for computing camera errors in SfM pipelines.

```
Benchmark                                               Avg Time(μs)  Peak Time(μs)  Iterations
--------------------------------------------------------------------------------
CORRESPONDING_CAMERAS_ALIGNMENT_10_centers_False               32219          36211          16
CORRESPONDING_CAMERAS_ALIGNMENT_10_centers_True                32429          36063          16
CORRESPONDING_CAMERAS_ALIGNMENT_10_extrinsics_False             5548           8782          91
CORRESPONDING_CAMERAS_ALIGNMENT_10_extrinsics_True              6153           9752          82
CORRESPONDING_CAMERAS_ALIGNMENT_100_centers_False              33344          40398          16
CORRESPONDING_CAMERAS_ALIGNMENT_100_centers_True               34528          37095          15
CORRESPONDING_CAMERAS_ALIGNMENT_100_extrinsics_False            5576           7187          90
CORRESPONDING_CAMERAS_ALIGNMENT_100_extrinsics_True             6256           9166          80
CORRESPONDING_CAMERAS_ALIGNMENT_1000_centers_False             32020          37247          16
CORRESPONDING_CAMERAS_ALIGNMENT_1000_centers_True              32776          37644          16
CORRESPONDING_CAMERAS_ALIGNMENT_1000_extrinsics_False           5336           8795          94
CORRESPONDING_CAMERAS_ALIGNMENT_1000_extrinsics_True            6266           9929          80
--------------------------------------------------------------------------------
```

Reviewed By: shapovalov

Differential Revision: D22946415

fbshipit-source-id: 8caae7ee365b304d8aa1f8133cf0dd92c35bc0dd
1 parent 14f015d commit 316b777

File tree

6 files changed: +482, -65 lines changed

pytorch3d/ops/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

-
+from .cameras_alignment import corresponding_cameras_alignment
 from .cubify import cubify
 from .graph_conv import GraphConv
 from .interp_face_attrs import interpolate_face_attributes

pytorch3d/ops/cameras_alignment.py

Lines changed: 215 additions & 0 deletions (new file)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

from typing import TYPE_CHECKING

import torch

from .. import ops


if TYPE_CHECKING:
    from pytorch3d.renderer.cameras import CamerasBase


def corresponding_cameras_alignment(
    cameras_src: "CamerasBase",
    cameras_tgt: "CamerasBase",
    estimate_scale: bool = True,
    mode: str = "extrinsics",
    eps: float = 1e-9,
) -> "CamerasBase":
    """
    .. warning::
        The `corresponding_cameras_alignment` API is experimental
        and subject to change!

    Estimates a single similarity transformation between two sets of cameras
    `cameras_src` and `cameras_tgt` and returns an aligned version of
    `cameras_src`.

    Given source cameras [(R_1, T_1), (R_2, T_2), ..., (R_N, T_N)] and
    target cameras [(R_1', T_1'), (R_2', T_2'), ..., (R_N', T_N')],
    where (R_i, T_i) is a 2-tuple of the camera rotation and translation matrix
    respectively, the algorithm finds a global rotation, translation and scale
    (R_A, T_A, s_A) which aligns all source cameras with the target cameras
    such that the following holds:

        Under the change of coordinates using a similarity transform
        (R_A, T_A, s_A) a 3D point X' is mapped to X with:
        ```
        X = (X' R_A + T_A) / s_A
        ```
        Then, for all cameras `i`, we assume that the following holds:
        ```
        X R_i + T_i = s' (X' R_i' + T_i'),
        ```
        i.e. an adjusted point X' is mapped by a camera (R_i', T_i')
        to the same point as imaged from camera (R_i, T_i) after resolving
        the scale ambiguity with a global scalar factor s'.

        Substituting for X above gives rise to the following:
        ```
        (X' R_A + T_A) / s_A R_i + T_i = s' (X' R_i' + T_i')     // · s_A
        (X' R_A + T_A) R_i + T_i s_A = (s' s_A) (X' R_i' + T_i')
        s' := 1 / s_A  # without loss of generality
        (X' R_A + T_A) R_i + T_i s_A = X' R_i' + T_i'
        X' R_A R_i + T_A R_i + T_i s_A = X' R_i' + T_i'
           ^^^^^^^   ^^^^^^^^^^^^^^^^^
           ~= R_i'        ~= T_i'
        ```
        i.e. after estimating R_A, T_A, s_A, the aligned source cameras have
        extrinsics:
            `cameras_src_align = (R_A R_i, T_A R_i + T_i s_A) ~= (R_i', T_i')`

    We support two ways `R_A, T_A, s_A` can be estimated:
        1) `mode=='centers'`
            Estimates the similarity alignment between camera centers using
            Umeyama's algorithm (see `pytorch3d.ops.corresponding_points_alignment`
            for details) and transforms camera extrinsics accordingly.

        2) `mode=='extrinsics'`
            Defines the alignment problem as a system
            of the following equations:
            ```
            for all i:
                [ R_A   0 ]  x  [ R_i         0 ]  =  [ R_i'  0 ]
                [ T_A^T 1 ]     [ (s_A T_i^T) 1 ]     [ T_i'  1 ]
            ```
            `R_A, T_A` and `s_A` are then obtained by solving the
            system in the least squares sense.

    The estimated camera transformation is a true similarity transform, i.e.
    it cannot be a reflection.

    Args:
        cameras_src: `N` cameras to be aligned.
        cameras_tgt: `N` target cameras.
        estimate_scale: Controls whether the alignment transform is rigid
            (`estimate_scale=False`), or a similarity (`estimate_scale=True`).
            `s_A` is set to `1` if `estimate_scale==False`.
        mode: Controls the alignment algorithm.
            Can be either `'centers'` or `'extrinsics'`. Please refer to the
            description above for details.
        eps: A scalar for clamping to avoid dividing by zero.
            Active when `estimate_scale==True`.

    Returns:
        cameras_src_aligned: `cameras_src` after applying the alignment transform.
    """

    if cameras_src.R.shape[0] != cameras_tgt.R.shape[0]:
        raise ValueError(
            "cameras_src and cameras_tgt have to contain the same number of cameras!"
        )

    if mode == "centers":
        align_fun = _align_camera_centers
    elif mode == "extrinsics":
        align_fun = _align_camera_extrinsics
    else:
        raise ValueError("mode has to be one of (centers, extrinsics)")

    align_t_R, align_t_T, align_t_s = align_fun(
        cameras_src, cameras_tgt, estimate_scale=estimate_scale, eps=eps
    )

    # create a new cameras object and set the R and T accordingly
    cameras_src_aligned = cameras_src.clone()
    cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
    cameras_src_aligned.T = (
        torch.bmm(
            align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1), cameras_src.R
        )[:, 0]
        + cameras_src.T * align_t_s
    )

    return cameras_src_aligned


def _align_camera_centers(
    cameras_src: "CamerasBase",
    cameras_tgt: "CamerasBase",
    estimate_scale: bool = True,
    eps: float = 1e-9,
):
    """
    Use Umeyama's algorithm to align the camera centers.
    """
    centers_src = cameras_src.get_camera_center()
    centers_tgt = cameras_tgt.get_camera_center()
    align_t = ops.corresponding_points_alignment(
        centers_src[None],
        centers_tgt[None],
        estimate_scale=estimate_scale,
        allow_reflection=False,
        eps=eps,
    )
    # the camera transform is the inverse of the estimated transform between centers
    align_t_R = align_t.R.permute(0, 2, 1)
    align_t_T = -(torch.bmm(align_t.T[:, None], align_t_R))[:, 0]
    align_t_s = align_t.s[0]

    return align_t_R, align_t_T, align_t_s


def _align_camera_extrinsics(
    cameras_src: "CamerasBase",
    cameras_tgt: "CamerasBase",
    estimate_scale: bool = True,
    eps: float = 1e-9,
):
    """
    Get the global rotation R_A with svd of cov(RR^T):
    ```
    R_A R_i = R_i' for all i
    R_A [R_1 R_2 ... R_N] = [R_1' R_2' ... R_N']
    U, _, V = svd([R_1 R_2 ... R_N]^T [R_1' R_2' ... R_N'])
    R_A = (U V^T)^T
    ```
    """
    RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
    U, _, V = torch.svd(RRcov)
    align_t_R = V @ U.t()

    """
    The translation + scale `T_A` and `s_A` is computed by finding
    a translation and scaling that aligns two tensors `A, B`
    defined as follows:
    ```
    T_A R_i + s_A T_i   = T_i'        ;  for all i    // · R_i^T
    s_A T_i R_i^T + T_A = T_i' R_i^T  ;  for all i
        ^^^^^^^^^         ^^^^^^^^^^
           A_i               B_i

    A_i := T_i R_i^T
    A = [A_1 A_2 ... A_N]
    B_i := T_i' R_i^T
    B = [B_1 B_2 ... B_N]
    ```
    The scale s_A can be retrieved by matching the correlations of
    the point sets A and B:
    ```
    s_A = (A-mean(A))*(B-mean(B)).sum() / ((A-mean(A))**2).sum()
    ```
    The translation `T_A` is then defined as:
    ```
    T_A = mean(B) - mean(A) * s_A
    ```
    """
    A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
    B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
    Amu = A.mean(0, keepdim=True)
    Bmu = B.mean(0, keepdim=True)
    if estimate_scale and A.shape[0] > 1:
        # get the scaling component by matching covariances
        # of centered A and centered B
        Ac = A - Amu
        Bc = B - Bmu
        align_t_s = (Ac * Bc).mean() / (Ac ** 2).mean().clamp(eps)
    else:
        # set the scale to identity
        align_t_s = 1.0
    # get the translation as the difference between the means of A and B
    align_t_T = Bmu - align_t_s * Amu

    return align_t_R, align_t_T, align_t_s
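For orientation, a minimal usage sketch of the new API (not part of this commit; the `PerspectiveCameras` setup, the random perturbation, and the error metric below are illustrative only):

```python
import torch
from pytorch3d.ops import corresponding_cameras_alignment
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.transforms import random_rotations

# ten target cameras with random extrinsics
R_tgt = random_rotations(10)
T_tgt = torch.randn(10, 3)
cameras_tgt = PerspectiveCameras(R=R_tgt, T=T_tgt)

# a noisy copy, standing in for an SfM reconstruction in its own coordinate frame
cameras_src = PerspectiveCameras(R=R_tgt.clone(), T=T_tgt + 0.05 * torch.randn(10, 3))

# estimate the similarity transform and apply it to the source cameras
cameras_aligned = corresponding_cameras_alignment(
    cameras_src, cameras_tgt, estimate_scale=True, mode="extrinsics"
)

# e.g. a camera-center error for evaluating an SfM pipeline
center_err = (
    cameras_aligned.get_camera_center() - cameras_tgt.get_camera_center()
).norm(dim=1).mean()
print(float(center_err))
```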

pytorch3d/renderer/cameras.py

Lines changed: 18 additions & 18 deletions

@@ -13,8 +13,8 @@


 # Default values for rotation and translation matrices.
-r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
-t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
+_R = torch.eye(3)[None]  # (1, 3, 3)
+_T = torch.zeros(1, 3)  # (1, 3)


 class CamerasBase(TensorProperties):

@@ -280,8 +280,8 @@ def OpenGLPerspectiveCameras(
     aspect_ratio=1.0,
     fov=60.0,
     degrees: bool = True,
-    R=r,
-    T=t,
+    R=_R,
+    T=_T,
     device="cpu",
 ):
     """

@@ -331,8 +331,8 @@ def __init__(
         aspect_ratio=1.0,
         fov=60.0,
         degrees: bool = True,
-        R=r,
-        T=t,
+        R=_R,
+        T=_T,
         device="cpu",
     ):
         """

@@ -436,7 +436,7 @@ def get_projection_transform(self, **kwargs) -> Transform3d:
         P[:, 2, 2] = z_sign * zfar / (zfar - znear)
         P[:, 2, 3] = -(zfar * znear) / (zfar - znear)

-        # Transpose the projection matrix as PyTorch3d transforms use row vectors.
+        # Transpose the projection matrix as PyTorch3D transforms use row vectors.
         transform = Transform3d(device=self.device)
         transform._matrix = P.transpose(1, 2).contiguous()
         return transform

@@ -494,8 +494,8 @@ def OpenGLOrthographicCameras(
     left=-1.0,
     right=1.0,
     scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
-    R=r,
-    T=t,
+    R=_R,
+    T=_T,
     device="cpu",
 ):
     """

@@ -540,8 +540,8 @@ def __init__(
         max_x=1.0,
         min_x=-1.0,
         scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
-        R=r,
-        T=t,
+        R=_R,
+        T=_T,
         device="cpu",
     ):
         """

@@ -688,7 +688,7 @@ def unproject_points(


 def SfMPerspectiveCameras(
-    focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
+    focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
 ):
     """
     SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead.

@@ -747,8 +747,8 @@ def __init__(
         self,
         focal_length=1.0,
         principal_point=((0.0, 0.0),),
-        R=r,
-        T=t,
+        R=_R,
+        T=_T,
         device="cpu",
         image_size=((-1, -1),),
     ):

@@ -848,7 +848,7 @@ def unproject_points(


 def SfMOrthographicCameras(
-    focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu"
+    focal_length=1.0, principal_point=((0.0, 0.0),), R=_R, T=_T, device="cpu"
 ):
     """
     SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead.

@@ -906,8 +906,8 @@ def __init__(
         self,
         focal_length=1.0,
         principal_point=((0.0, 0.0),),
-        R=r,
-        T=t,
+        R=_R,
+        T=_T,
         device="cpu",
         image_size=((-1, -1),),
     ):

@@ -1109,7 +1109,7 @@ def _get_sfm_calibration_matrix(
 ################################################


-def get_world_to_view_transform(R=r, T=t) -> Transform3d:
+def get_world_to_view_transform(R=_R, T=_T) -> Transform3d:
     """
     This function returns a Transform3d representing the transformation
     matrix to go from world space to view space by applying a rotation and
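The renamed defaults are plain torch tensors rather than numpy arrays. A small sanity-check sketch (written for this illustration, not part of the commit) of what the identity defaults imply for `get_world_to_view_transform`:

```python
import torch
from pytorch3d.renderer.cameras import get_world_to_view_transform

# with the default R (identity) and T (zeros), world-to-view is the identity transform
wtv = get_world_to_view_transform()
pts = torch.randn(8, 3)
assert torch.allclose(wtv.transform_points(pts), pts)
```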

tests/bm_cameras_alignment.py

Lines changed: 23 additions & 0 deletions (new file)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import itertools
from fvcore.common.benchmark import benchmark
from test_cameras_alignment import TestCamerasAlignment


def bm_cameras_alignment() -> None:

    case_grid = {
        "batch_size": [10, 100, 1000],
        "mode": ["centers", "extrinsics"],
        "estimate_scale": [False, True],
    }
    test_cases = itertools.product(*case_grid.values())
    kwargs_list = [dict(zip(case_grid.keys(), case)) for case in test_cases]

    benchmark(
        TestCamerasAlignment.corresponding_cameras_alignment,
        "CORRESPONDING_CAMERAS_ALIGNMENT",
        kwargs_list,
        warmup_iters=1,
    )
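For reference, a quick sketch (illustrative, outside the commit) of how the `case_grid` above expands into the benchmark's keyword-argument list: 3 batch sizes x 2 modes x 2 scale flags give 12 cases, matching the 12 rows of the benchmark table in the summary.

```python
import itertools

case_grid = {
    "batch_size": [10, 100, 1000],
    "mode": ["centers", "extrinsics"],
    "estimate_scale": [False, True],
}
kwargs_list = [
    dict(zip(case_grid.keys(), case)) for case in itertools.product(*case_grid.values())
]
print(len(kwargs_list))  # 12
print(kwargs_list[0])    # {'batch_size': 10, 'mode': 'centers', 'estimate_scale': False}
```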
