Skip to content

Commit 4bfe715

Browse files
bottlerfacebook-github-bot
authored andcommitted
mesh_normal_consistency speedup
Summary: One step in finding all the pairs of vertices which share faces is a simple calculation but annoying to parallelize. It was implemented in pure Python. We move it to C++. We still pull the data to the CPU and put the answer back on the device. Reviewed By: nikhilaravi, gkioxari Differential Revision: D26073475 fbshipit-source-id: ffbf4e2c347a511ab5084bceff600465812b6a52
1 parent 5ac2f42 commit 4bfe715

File tree

4 files changed

+84
-28
lines changed

4 files changed

+84
-28
lines changed

pytorch3d/csrc/ext.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "gather_scatter/gather_scatter.h"
1515
#include "interp_face_attrs/interp_face_attrs.h"
1616
#include "knn/knn.h"
17+
#include "mesh_normal_consistency/mesh_normal_consistency.h"
1718
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
1819
#include "point_mesh/point_mesh_cuda.h"
1920
#include "rasterize_meshes/rasterize_meshes.h"
@@ -31,6 +32,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
3132
#endif
3233
m.def("knn_points_idx", &KNearestNeighborIdx);
3334
m.def("knn_points_backward", &KNearestNeighborBackward);
35+
m.def(
36+
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
3437
m.def("gather_scatter", &GatherScatter);
3538
m.def("rasterize_points", &RasterizePoints);
3639
m.def("rasterize_points_backward", &RasterizePointsBackward);
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#pragma once
4+
#include <torch/extension.h>
5+
#include "utils/pytorch3d_cutils.h"
6+
7+
// For mesh_normal_consistency, find pairs of vertices opposite the same edge.
8+
//
9+
// Args:
10+
// edge_num: int64 Tensor of shape (E,) giving the number of vertices
11+
// corresponding to each edge.
12+
//
13+
// Returns:
14+
// pairs: int64 Tensor of shape (N,2)
15+
16+
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num);
17+
18+
// Exposed implementation.
19+
at::Tensor MeshNormalConsistencyFindVertices(const at::Tensor& edge_num) {
20+
if (edge_num.is_cuda()) {
21+
AT_ERROR("This function needs a CPU tensor.");
22+
}
23+
return MeshNormalConsistencyFindVerticesCpu(edge_num);
24+
}
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
#include <ATen/ATen.h>
4+
#include <utility>
5+
#include <vector>
6+
7+
at::Tensor MeshNormalConsistencyFindVerticesCpu(const at::Tensor& edge_num) {
8+
// We take a LongTensor of shape (E,) giving the number of things intersecting
9+
// each edge. The things are taken to be numbered in order.
10+
// (In fact, the "things" are opposite vertices to edges, renumbered).
11+
// We return a tensor of shape (?, 2) where for every pair of things which
12+
// intersect the same edge there is a row of their numbers in the output.
13+
14+
// Example possible inputs and outputs (order of output is not specified):
15+
// [1,0,1,1,0] => [[]]
16+
// [3] => [[0,1], [0,2], [1,2]]
17+
// [0,3] => [[0,1], [0,2], [1,2]]
18+
// [1,3] => [[1,2], [1,3], [2,3]]
19+
//[1,0,2,1,0,2] => [[1,2], [4,5]]
20+
21+
const auto num_edges = edge_num.size(0);
22+
auto edges_a = edge_num.accessor<int64_t, 1>();
23+
24+
int64_t vert_idx = 0;
25+
std::vector<std::pair<int64_t, int64_t>> pairs;
26+
for (int64_t i_edge = 0; i_edge < num_edges; ++i_edge) {
27+
int64_t e = edges_a[i_edge];
28+
for (int64_t j = 0; j < e; ++j) {
29+
for (int64_t i = 0; i < j; ++i) {
30+
pairs.emplace_back(vert_idx + i, vert_idx + j);
31+
}
32+
}
33+
vert_idx += e;
34+
}
35+
36+
// Convert from std::vector by copying over the items to a new empty torch
37+
// tensor.
38+
auto pairs_tensor = at::empty({(int64_t)pairs.size(), 2}, edge_num.options());
39+
auto pairs_a = pairs_tensor.accessor<int64_t, 2>();
40+
for (int64_t i_pair = 0; i_pair < pairs.size(); ++i_pair) {
41+
auto accessor = pairs_a[i_pair];
42+
accessor[0] = pairs[i_pair].first;
43+
accessor[1] = pairs[i_pair].second;
44+
}
45+
46+
return pairs_tensor;
47+
}

pytorch3d/loss/mesh_normal_consistency.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3-
4-
from itertools import islice
5-
63
import torch
74

5+
# pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
6+
from pytorch3d import _C
7+
88

99
def mesh_normal_consistency(meshes):
1010
r"""
@@ -71,9 +71,9 @@ def mesh_normal_consistency(meshes):
7171
F = faces_packed.shape[0] # sum(F_n)
7272

7373
# We don't want gradients for the following operation. The goal is to
74-
# find for each edge e all the vertices associated with e. In the example above,
75-
# the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1)
76-
# and points connected on faces to e (=a, b).
74+
# find for each edge e all the vertices associated with e. In the example
75+
# above, the vertices associated with e are (a, b), i.e. the points connected
76+
# on faces to e.
7777
with torch.no_grad():
7878
edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges
7979
vert_idx = (
@@ -95,23 +95,10 @@ def mesh_normal_consistency(meshes):
9595
# the number of vertices which are associated with each edge.
9696
# There can be a different number for each edge.
9797
edge_num = edge_idx.bincount(minlength=E)
98-
# Create pairs of vertices associated to e. We generate a list of lists:
99-
# each list has the indices of the vertices which are opposite to one edge.
100-
# The length of the list for each edge will vary.
101-
vert_edge_pair_idx = split_list(
102-
list(range(edge_idx.shape[0])), edge_num.tolist()
103-
)
104-
# For each list find all combinations of pairs in the list. This represents
105-
# all pairs of vertices which are opposite to the same edge.
106-
vert_edge_pair_idx = [
107-
[e[i], e[j]]
108-
for e in vert_edge_pair_idx
109-
for i in range(len(e) - 1)
110-
for j in range(1, len(e))
111-
if i < j
112-
]
113-
vert_edge_pair_idx = torch.tensor(
114-
vert_edge_pair_idx, device=meshes.device, dtype=torch.int64
98+
99+
# This calculates all pairs of vertices which are opposite to the same edge.
100+
vert_edge_pair_idx = _C.mesh_normal_consistency_find_verts(edge_num.cpu()).to(
101+
edge_num.device
115102
)
116103

117104
if vert_edge_pair_idx.shape[0] == 0:
@@ -141,8 +128,3 @@ def mesh_normal_consistency(meshes):
141128

142129
loss = loss * weights
143130
return loss.sum() / N
144-
145-
146-
def split_list(input, length_to_split):
147-
inputt = iter(input)
148-
return [list(islice(inputt, elem)) for elem in length_to_split]

0 commit comments

Comments
 (0)