Skip to content

Commit ebac66d

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Classic Marching Cubes algorithm implementation
Summary: Defines a function to run marching cubes algorithm on a single or batch of 3D scalar fields. Returns a mesh's faces and vertices. UPDATES (12/18) - Input data is now specified as a (B, D, H, W) tensor as opposed to a (B, W, H, D) tensor. This will now be compatible with the Volumes datastructure. - Add an option to return output vertices in local coordinates instead of world coordinates. Also added a small fix to remove the dype for device in Transforms3D - if passing in a torch.device instead of str it causes a pyre error. Reviewed By: jcjohnson Differential Revision: D24599019 fbshipit-source-id: 90554a200319fed8736a12371cc349e7108aacd0
1 parent 9c6b58c commit ebac66d

File tree

7 files changed

+1693
-5
lines changed

7 files changed

+1693
-5
lines changed

pytorch3d/ops/marching_cubes.py

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from typing import Dict, List, Optional, Tuple
4+
5+
import torch
6+
from pytorch3d.ops.marching_cubes_data import EDGE_TABLE, EDGE_TO_VERTICES, FACE_TABLE
7+
from pytorch3d.transforms import Translate
8+
9+
10+
EPS = 0.00001
11+
12+
13+
class Cube:
14+
def __init__(self, bfl_vertex: Tuple[int, int, int], spacing: int = 1):
15+
"""
16+
Initializes a cube given the bottom front left vertex coordinate
17+
and the cube spacing
18+
19+
Edge and vertex convention:
20+
21+
v4_______e4____________v5
22+
/| /|
23+
/ | / |
24+
e7/ | e5/ |
25+
/___|______e6_________/ |
26+
v7| | |v6 |e9
27+
| | | |
28+
| |e8 |e10|
29+
e11| | | |
30+
| |_________________|___|
31+
| / v0 e0 | /v1
32+
| / | /
33+
| /e3 | /e1
34+
|/_____________________|/
35+
v3 e2 v2
36+
37+
Args:
38+
bfl_vertex: a tuple of size 3 corresponding to the bottom front left vertex
39+
of the cube in (x, y, z) format
40+
spacing: the length of each edge of the cube
41+
"""
42+
# match corner orders to algorithm convention
43+
if len(bfl_vertex) != 3:
44+
msg = "The vertex {} is size {} instead of size 3".format(
45+
bfl_vertex, len(bfl_vertex)
46+
)
47+
raise ValueError(msg)
48+
49+
x, y, z = bfl_vertex
50+
self.vertices = torch.tensor(
51+
[
52+
[x, y, z + spacing],
53+
[x + spacing, y, z + spacing],
54+
[x + spacing, y, z],
55+
[x, y, z],
56+
[x, y + spacing, z + spacing],
57+
[x + spacing, y + spacing, z + spacing],
58+
[x + spacing, y + spacing, z],
59+
[x, y + spacing, z],
60+
]
61+
)
62+
63+
def get_index(self, volume_data: torch.Tensor, isolevel: float) -> int:
64+
"""
65+
Calculates the cube_index in the range 0-255 to index
66+
into EDGE_TABLE and FACE_TABLE
67+
Args:
68+
volume_data: the 3D scalar data
69+
isolevel: the isosurface value used as a threshold
70+
for determining whether a point is inside/outside
71+
the volume
72+
"""
73+
cube_index = 0
74+
bit = 1
75+
for index in range(len(self.vertices)):
76+
vertex = self.vertices[index]
77+
value = _get_value(vertex, volume_data)
78+
if value < isolevel:
79+
cube_index |= bit
80+
bit *= 2
81+
return cube_index
82+
83+
84+
def marching_cubes_naive(
85+
volume_data_batch: torch.Tensor,
86+
isolevel: Optional[float] = None,
87+
spacing: int = 1,
88+
return_local_coords: bool = True,
89+
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
90+
"""
91+
Runs the classic marching cubes algorithm, iterating over
92+
the coordinates of the volume_data and using a given isolevel
93+
for determining intersected edges of cubes of size `spacing`.
94+
Returns vertices and faces of the obtained mesh.
95+
This operation is non-differentiable.
96+
97+
This is a naive implementation, and is not optimized for efficiency.
98+
99+
Args:
100+
volume_data_batch: a Tensor of size (N, D, H, W) corresponding to
101+
a batch of 3D scalar fields
102+
isolevel: the isosurface value to use as the threshold to determine
103+
whether points are within a volume. If None, then the average of the
104+
maximum and minimum value of the scalar field will be used.
105+
spacing: an integer specifying the cube size to use
106+
return_local_coords: bool. If True the output vertices will be in local coordinates in
107+
the range [-1, 1] x [-1, 1] x [-1, 1]. If False they will be in the range
108+
[0, W-1] x [0, H-1] x [0, D-1]
109+
Returns:
110+
verts: [(V_0, 3), (V_1, 3), ...] List of N FloatTensors of vertices.
111+
faces: [(F_0, 3), (F_1, 3), ...] List of N LongTensors of faces.
112+
"""
113+
volume_data_batch = volume_data_batch.detach().cpu()
114+
batched_verts, batched_faces = [], []
115+
D, H, W = volume_data_batch.shape[1:]
116+
# pyre-ignore [16]
117+
volume_size_xyz = volume_data_batch.new_tensor([W, H, D])[None]
118+
119+
if return_local_coords:
120+
# Convert from local coordinates in the range [-1, 1] range to
121+
# world coordinates in the range [0, D-1], [0, H-1], [0, W-1]
122+
local_to_world_transform = Translate(
123+
x=+1.0, y=+1.0, z=+1.0, device=volume_data_batch.device
124+
).scale((volume_size_xyz - 1) * spacing * 0.5)
125+
# Perform the inverse to go from world to local
126+
world_to_local_transform = local_to_world_transform.inverse()
127+
128+
for i in range(len(volume_data_batch)):
129+
volume_data = volume_data_batch[i]
130+
curr_isolevel = (
131+
((volume_data.max() + volume_data.min()) / 2).item()
132+
if isolevel is None
133+
else isolevel
134+
)
135+
edge_vertices_to_index = {}
136+
vertex_coords_to_index = {}
137+
verts, faces = [], []
138+
# Use length - spacing for the bounds since we are using
139+
# cubes of size spacing, with the lowest x,y,z values
140+
# (bottom front left)
141+
for x in range(0, W - spacing, spacing):
142+
for y in range(0, H - spacing, spacing):
143+
for z in range(0, D - spacing, spacing):
144+
cube = Cube((x, y, z), spacing)
145+
new_verts, new_faces = polygonise(
146+
cube,
147+
curr_isolevel,
148+
volume_data,
149+
edge_vertices_to_index,
150+
vertex_coords_to_index,
151+
)
152+
verts.extend(new_verts)
153+
faces.extend(new_faces)
154+
if len(faces) > 0 and len(verts) > 0:
155+
verts = torch.tensor(verts, dtype=torch.float32)
156+
# Convert vertices from world to local coords
157+
if return_local_coords:
158+
verts = world_to_local_transform.transform_points(verts[None, ...])
159+
verts = verts.squeeze()
160+
batched_verts.append(verts)
161+
batched_faces.append(torch.tensor(faces, dtype=torch.int64))
162+
return batched_verts, batched_faces
163+
164+
165+
def polygonise(
166+
cube: Cube,
167+
isolevel: float,
168+
volume_data: torch.Tensor,
169+
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
170+
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
171+
) -> Tuple[list, list]:
172+
"""
173+
Runs the classic marching cubes algorithm for one Cube in the volume.
174+
Returns the vertices and faces for the given cube.
175+
176+
Args:
177+
cube: a Cube indicating the cube being examined for edges that intersect
178+
the volume data.
179+
isolevel: the isosurface value to use as the threshold to determine
180+
whether points are within a volume.
181+
volume_data: a Tensor of shape (D, H, W) corresponding to
182+
a 3D scalar field
183+
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
184+
to the index of its interpolated point, if that interpolated point
185+
has already been used by a previous point
186+
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
187+
index of that vertex, if that point has already been marked as a vertex.
188+
Returns:
189+
verts: List of triangle vertices for the given cube in the volume
190+
faces: List of triangle faces for the given cube in the volume
191+
"""
192+
num_existing_verts = max(edge_vertices_to_index.values(), default=-1) + 1
193+
verts, faces = [], []
194+
cube_index = cube.get_index(volume_data, isolevel)
195+
edges = EDGE_TABLE[cube_index]
196+
edge_indices = _get_edge_indices(edges)
197+
if len(edge_indices) == 0:
198+
return [], []
199+
200+
new_verts, edge_index_to_point_index = _calculate_interp_vertices(
201+
edge_indices,
202+
volume_data,
203+
cube,
204+
isolevel,
205+
edge_vertices_to_index,
206+
vertex_coords_to_index,
207+
num_existing_verts,
208+
)
209+
210+
# Create faces
211+
face_triangles = FACE_TABLE[cube_index]
212+
for i in range(0, len(face_triangles), 3):
213+
tri1 = edge_index_to_point_index[face_triangles[i]]
214+
tri2 = edge_index_to_point_index[face_triangles[i + 1]]
215+
tri3 = edge_index_to_point_index[face_triangles[i + 2]]
216+
if tri1 != tri2 and tri2 != tri3 and tri1 != tri3:
217+
faces.append([tri1, tri2, tri3])
218+
219+
verts += new_verts
220+
return verts, faces
221+
222+
223+
def _get_edge_indices(edges: int) -> List[int]:
224+
"""
225+
Finds which edge numbers are intersected given the bit representation
226+
detailed in marching_cubes_data.EDGE_TABLE.
227+
228+
Args:
229+
edges: an integer corresponding to the value at cube_index
230+
from the EDGE_TABLE in marching_cubes_data.py
231+
232+
Returns:
233+
edge_indices: A list of edge indices
234+
"""
235+
if edges == 0:
236+
return []
237+
238+
edge_indices = []
239+
for i in range(12):
240+
if edges & (2 ** i):
241+
edge_indices.append(i)
242+
return edge_indices
243+
244+
245+
def _calculate_interp_vertices(
246+
edge_indices: List[int],
247+
volume_data: torch.Tensor,
248+
cube: Cube,
249+
isolevel: float,
250+
edge_vertices_to_index: Dict[Tuple[Tuple, Tuple], int],
251+
vertex_coords_to_index: Dict[Tuple[float, float, float], int],
252+
num_existing_verts: int,
253+
) -> Tuple[List, Dict[int, int]]:
254+
"""
255+
Finds the interpolated vertices for the intersected edges, either referencing
256+
previous calculations or newly calculating and storing the new interpolated
257+
points.
258+
259+
Args:
260+
edge_indices: the numbers of the edges which are intersected. See the
261+
Cube class for more detail on the edge numbering convention.
262+
volume_data: a Tensor of size (D, H, W) corresponding to
263+
a 3D scalar field
264+
cube: a Cube indicating the cube being examined for edges that intersect
265+
the volume
266+
isolevel: the isosurface value to use as the threshold to determine
267+
whether points are within a volume.
268+
edge_vertices_to_index: A dictionary which maps an edge's two coordinates
269+
to the index of its interpolated point, if that interpolated point
270+
has already been used by a previous point
271+
vertex_coords_to_index: A dictionary mapping a point (x, y, z) to the corresponding
272+
index of that vertex, if that point has already been marked as a vertex.
273+
num_existing_verts: the number of vertices that have been found in previous
274+
calls to polygonise for the given volume_data in the above function, marching_cubes.
275+
This is equal to the 1 + the maximum value in edge_vertices_to_index.
276+
Returns:
277+
interp_points: a list of new interpolated points
278+
edge_index_to_point_index: a dictionary mapping an edge number to the index in the
279+
marching cubes' vertices list of the interpolated point on that edge. To be precise,
280+
it refers to the index within the vertices list after interp_points
281+
has been appended to the verts list constructed in the marching_cubes_naive
282+
function.
283+
"""
284+
interp_points = []
285+
edge_index_to_point_index = {}
286+
for edge_index in edge_indices:
287+
v1, v2 = EDGE_TO_VERTICES[edge_index]
288+
point1, point2 = cube.vertices[v1], cube.vertices[v2]
289+
p_tuple1, p_tuple2 = tuple(point1.tolist()), tuple(point2.tolist())
290+
if (p_tuple1, p_tuple2) in edge_vertices_to_index:
291+
edge_index_to_point_index[edge_index] = edge_vertices_to_index[
292+
(p_tuple1, p_tuple2)
293+
]
294+
else:
295+
val1, val2 = _get_value(point1, volume_data), _get_value(
296+
point2, volume_data
297+
)
298+
299+
point = None
300+
if abs(isolevel - val1) < EPS:
301+
point = point1
302+
303+
if abs(isolevel - val2) < EPS:
304+
point = point2
305+
306+
if abs(val1 - val2) < EPS:
307+
point = point1
308+
309+
if point is None:
310+
mu = (isolevel - val1) / (val2 - val1)
311+
x1, y1, z1 = point1
312+
x2, y2, z2 = point2
313+
x = x1 + mu * (x2 - x1)
314+
y = y1 + mu * (y2 - y1)
315+
z = z1 + mu * (z2 - z1)
316+
else:
317+
x, y, z = point
318+
319+
x, y, z = x.item(), y.item(), z.item() # for dictionary keys
320+
321+
vert_index = None
322+
if (x, y, z) in vertex_coords_to_index:
323+
vert_index = vertex_coords_to_index[(x, y, z)]
324+
else:
325+
vert_index = num_existing_verts + len(interp_points)
326+
interp_points.append([x, y, z])
327+
vertex_coords_to_index[(x, y, z)] = vert_index
328+
329+
edge_vertices_to_index[(p_tuple1, p_tuple2)] = vert_index
330+
edge_index_to_point_index[edge_index] = vert_index
331+
332+
return interp_points, edge_index_to_point_index
333+
334+
335+
def _get_value(point: Tuple[int, int, int], volume_data: torch.Tensor) -> float:
336+
"""
337+
Gets the value at a given coordinate point in the scalar field.
338+
339+
Args:
340+
point: data of shape (3) corresponding to an xyz coordinate.
341+
volume_data: a Tensor of size (D, H, W) corresponding to
342+
a 3D scalar field
343+
Returns:
344+
data: scalar value in the volume at the given point
345+
"""
346+
x, y, z = point
347+
return volume_data[z][y][x]

0 commit comments

Comments
 (0)