Skip to content

Commit e5b1d6d

Browse files
davnov134 authored and facebook-github-bot committed
Umeyama
Summary: Umeyama estimates a rigid motion between two sets of corresponding points. Benchmark output for `bm_points_alignment` ``` Arguments key: [<allow_reflection>_<batch_size>_<dim>_<estimate_scale>_<n_points>_<use_pointclouds>] Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- CorrespodingPointsAlignment_True_1_3_True_100_False 7382 9833 68 CorrespodingPointsAlignment_True_1_3_True_10000_False 8183 10500 62 CorrespodingPointsAlignment_True_1_3_False_100_False 7301 9263 69 CorrespodingPointsAlignment_True_1_3_False_10000_False 7945 9746 64 CorrespodingPointsAlignment_True_1_20_True_100_False 13706 41623 37 CorrespodingPointsAlignment_True_1_20_True_10000_False 11044 33766 46 CorrespodingPointsAlignment_True_1_20_False_100_False 9908 28791 51 CorrespodingPointsAlignment_True_1_20_False_10000_False 9523 18680 53 CorrespodingPointsAlignment_True_10_3_True_100_False 29585 32026 17 CorrespodingPointsAlignment_True_10_3_True_10000_False 29626 36324 18 CorrespodingPointsAlignment_True_10_3_False_100_False 26013 29253 20 CorrespodingPointsAlignment_True_10_3_False_10000_False 25000 33820 20 CorrespodingPointsAlignment_True_10_20_True_100_False 40955 41592 13 CorrespodingPointsAlignment_True_10_20_True_10000_False 42087 42393 12 CorrespodingPointsAlignment_True_10_20_False_100_False 39863 40381 13 CorrespodingPointsAlignment_True_10_20_False_10000_False 40813 41699 13 CorrespodingPointsAlignment_True_100_3_True_100_False 183146 194745 3 CorrespodingPointsAlignment_True_100_3_True_10000_False 213789 231466 3 CorrespodingPointsAlignment_True_100_3_False_100_False 177805 180796 3 CorrespodingPointsAlignment_True_100_3_False_10000_False 184963 185695 3 CorrespodingPointsAlignment_True_100_20_True_100_False 347181 347325 2 CorrespodingPointsAlignment_True_100_20_True_10000_False 363259 363613 2 CorrespodingPointsAlignment_True_100_20_False_100_False 351769 352496 2 
CorrespodingPointsAlignment_True_100_20_False_10000_False 375629 379818 2 CorrespodingPointsAlignment_False_1_3_True_100_False 11155 13770 45 CorrespodingPointsAlignment_False_1_3_True_10000_False 10743 13938 47 CorrespodingPointsAlignment_False_1_3_False_100_False 9578 11511 53 CorrespodingPointsAlignment_False_1_3_False_10000_False 9549 11984 53 CorrespodingPointsAlignment_False_1_20_True_100_False 13809 14183 37 CorrespodingPointsAlignment_False_1_20_True_10000_False 14084 15082 36 CorrespodingPointsAlignment_False_1_20_False_100_False 12765 14177 40 CorrespodingPointsAlignment_False_1_20_False_10000_False 12811 13096 40 CorrespodingPointsAlignment_False_10_3_True_100_False 28823 39384 18 CorrespodingPointsAlignment_False_10_3_True_10000_False 27135 27525 19 CorrespodingPointsAlignment_False_10_3_False_100_False 26236 28980 20 CorrespodingPointsAlignment_False_10_3_False_10000_False 42324 45123 12 CorrespodingPointsAlignment_False_10_20_True_100_False 723902 723902 1 CorrespodingPointsAlignment_False_10_20_True_10000_False 220007 252886 3 CorrespodingPointsAlignment_False_10_20_False_100_False 55593 71636 9 CorrespodingPointsAlignment_False_10_20_False_10000_False 44419 71861 12 CorrespodingPointsAlignment_False_100_3_True_100_False 184768 185199 3 CorrespodingPointsAlignment_False_100_3_True_10000_False 198657 213868 3 CorrespodingPointsAlignment_False_100_3_False_100_False 224598 309645 3 CorrespodingPointsAlignment_False_100_3_False_10000_False 197863 202002 3 CorrespodingPointsAlignment_False_100_20_True_100_False 293484 309459 2 CorrespodingPointsAlignment_False_100_20_True_10000_False 327253 366644 2 CorrespodingPointsAlignment_False_100_20_False_100_False 420793 422194 2 CorrespodingPointsAlignment_False_100_20_False_10000_False 462634 485542 2 CorrespodingPointsAlignment_True_1_3_True_100_True 7664 9909 66 CorrespodingPointsAlignment_True_1_3_True_10000_True 7190 8366 70 CorrespodingPointsAlignment_True_1_3_False_100_True 6549 8316 77 
CorrespodingPointsAlignment_True_1_3_False_10000_True 6534 7710 77 CorrespodingPointsAlignment_True_10_3_True_100_True 29052 32940 18 CorrespodingPointsAlignment_True_10_3_True_10000_True 30526 33453 17 CorrespodingPointsAlignment_True_10_3_False_100_True 28708 32993 18 CorrespodingPointsAlignment_True_10_3_False_10000_True 30630 35973 17 CorrespodingPointsAlignment_True_100_3_True_100_True 264909 320820 3 CorrespodingPointsAlignment_True_100_3_True_10000_True 310902 322604 2 CorrespodingPointsAlignment_True_100_3_False_100_True 246832 250634 3 CorrespodingPointsAlignment_True_100_3_False_10000_True 276006 289061 2 CorrespodingPointsAlignment_False_1_3_True_100_True 11421 13757 44 CorrespodingPointsAlignment_False_1_3_True_10000_True 11199 12532 45 CorrespodingPointsAlignment_False_1_3_False_100_True 11474 15841 44 CorrespodingPointsAlignment_False_1_3_False_10000_True 10384 13188 49 CorrespodingPointsAlignment_False_10_3_True_100_True 36599 47340 14 CorrespodingPointsAlignment_False_10_3_True_10000_True 40702 50754 13 CorrespodingPointsAlignment_False_10_3_False_100_True 41277 52149 13 CorrespodingPointsAlignment_False_10_3_False_10000_True 34286 37091 15 CorrespodingPointsAlignment_False_100_3_True_100_True 254991 258578 2 CorrespodingPointsAlignment_False_100_3_True_10000_True 257999 261285 2 CorrespodingPointsAlignment_False_100_3_False_100_True 247511 248693 3 CorrespodingPointsAlignment_False_100_3_False_10000_True 251807 263865 3 ``` Reviewed By: gkioxari Differential Revision: D19808389 fbshipit-source-id: 83305a58627d2fc5dcaf3c3015132d8148f28c29
1 parent 745aaf3 commit e5b1d6d

File tree

4 files changed

+550
-0
lines changed

4 files changed

+550
-0
lines changed

pytorch3d/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .mesh_face_areas_normals import mesh_face_areas_normals
77
from .nearest_neighbor_points import nn_points_idx
88
from .packed_to_padded import packed_to_padded, padded_to_packed
9+
from .points_alignment import corresponding_points_alignment
910
from .sample_points_from_meshes import sample_points_from_meshes
1011
from .subdivide_meshes import SubdivideMeshes
1112
from .vert_align import vert_align

pytorch3d/ops/points_alignment.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3+
4+
import warnings
5+
from typing import Tuple, Union
6+
import torch
7+
8+
from pytorch3d.structures.pointclouds import Pointclouds
9+
10+
11+
def corresponding_points_alignment(
    X: Union[torch.Tensor, Pointclouds],
    Y: Union[torch.Tensor, Pointclouds],
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    eps: float = 1e-8,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Finds a similarity transformation (rotation `R`, translation `T`
    and optionally scale `s`) between two given sets of corresponding
    `d`-dimensional points `X` and `Y` such that:

    `s[i] X[i] R[i] + T[i] = Y[i]`,

    for all batch indexes `i` in the least squares sense.

    The algorithm is also known as Umeyama [1].

    Args:
        X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
            or a `Pointclouds` object.
        estimate_scale: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes an identity
            scale and returns a tensor of ones.
        allow_reflection: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        eps: A scalar for clamping to avoid dividing by zero. Active for the
            code that estimates the output scale `s`.

    Raises:
        ValueError: If `X` and `Y` do not have the same shape or the same
            number of points per batch element.

    Returns:
        3-element tuple containing
        - **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
        - **T**: Batch of translations of shape `(minibatch, d)`.
        - **s**: batch of scaling factors of shape `(minibatch, )`.

    References:
        [1] Shinji Umeyama: Least-Squares Estimation of
        Transformation Parameters Between Two Point Patterns
    """

    # make sure we convert input Pointclouds structures to tensors
    # (padded points + per-cloud point counts)
    Xt, num_points = _convert_point_cloud_to_tensor(X)
    Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)

    if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
        raise ValueError(
            "Point sets X and Y have to have the same \
            number of batches, points and dimensions."
        )

    b, n, dim = Xt.shape

    # compute the centroids of the point sets;
    # clamp prevents 0/0 for empty clouds, and padded entries are zero
    # so they do not contribute to the sums
    Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
    Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)

    # mean-center the point sets
    Xc = Xt - Xmu[:, None]
    Yc = Yt - Ymu[:, None]

    if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
        # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
        mask = (
            torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
            < num_points[:, None]
        ).type_as(Xc)
        Xc *= mask[:, :, None]
        Yc *= mask[:, :, None]

    if (num_points < (dim + 1)).any():
        # NOTE: the condition triggers for clouds with fewer than dim+1
        # points; the message is phrased to match the check.
        warnings.warn(
            "The size of one of the point clouds is < dim+1. "
            + "corresponding_points_alignment can't return a unique solution."
        )

    # compute the covariance XYcov between the point sets Xc, Yc
    XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
    XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)

    # decompose the covariance matrix XYcov
    U, S, V = torch.svd(XYcov)

    # identity matrix used for fixing reflections
    E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(
        b, 1, 1
    )

    if not allow_reflection:
        # reflection test:
        #   checks whether the estimated rotation has det==1,
        #   if not, finds the nearest rotation s.t. det==1 by
        #   flipping the sign of the last singular vector U
        R_test = torch.bmm(U, V.transpose(2, 1))
        E[:, -1, -1] = torch.det(R_test)

    # find the rotation matrix by composing U and V again
    R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))

    if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
        Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)

        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)

        # translation component
        T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]

    else:
        # translation component
        T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]

        # unit scaling since we do not estimate scale
        s = T.new_ones(b)

    return R, T, s
130+
131+
132+
def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
133+
"""
134+
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
135+
padded representation and returns it together with the number of points
136+
per batch. Otherwise, returns the input itself with the number of points
137+
set to the size of the second dimension of `pcl`.
138+
"""
139+
if isinstance(pcl, Pointclouds):
140+
X = pcl.points_padded()
141+
num_points = pcl.num_points_per_cloud()
142+
elif torch.is_tensor(pcl):
143+
X = pcl
144+
num_points = X.shape[1] * torch.ones(
145+
X.shape[0], device=X.device, dtype=torch.int64
146+
)
147+
else:
148+
raise ValueError(
149+
"The inputs X, Y should be either Pointclouds objects or tensors."
150+
)
151+
return X, num_points

tests/bm_points_alignment.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3+
4+
from copy import deepcopy
5+
from itertools import product
6+
from fvcore.common.benchmark import benchmark
7+
8+
from test_points_alignment import TestCorrespondingPointsAlignment
9+
10+
11+
def bm_corresponding_points_alignment() -> None:
    """
    Benchmarks `corresponding_points_alignment` over a grid of settings:
    reflection handling, batch size, dimensionality, scale estimation,
    number of points, and tensor vs. Pointclouds inputs.
    """

    case_grid = {
        "allow_reflection": [True, False],
        "batch_size": [1, 10, 100],
        "dim": [3, 20],
        "estimate_scale": [True, False],
        "n_points": [100, 10000],
        "use_pointclouds": [False],
    }

    # expand the grid into one kwargs dict per case; keys are sorted so
    # the case ordering is deterministic across runs
    test_args = sorted(case_grid.keys())
    test_cases = product(*[case_grid[k] for k in test_args])
    kwargs_list = [dict(zip(test_args, case)) for case in test_cases]

    # add the use_pointclouds=True test cases whenever we have dim==3
    # (presumably because Pointclouds inputs are 3-dimensional — confirm)
    kwargs_to_add = []
    for entry in kwargs_list:
        if entry["dim"] == 3:
            entry_add = deepcopy(entry)
            entry_add["use_pointclouds"] = True
            kwargs_to_add.append(entry_add)
    kwargs_list.extend(kwargs_to_add)

    benchmark(
        TestCorrespondingPointsAlignment.corresponding_points_alignment,
        # fixed typo in the benchmark label: "Correspoding" -> "Corresponding"
        "CorrespondingPointsAlignment",
        kwargs_list,
        warmup_iters=1,
    )

0 commit comments

Comments
 (0)