Skip to content

Commit e12a081

Browse files
bottlerfacebook-github-bot
authored andcommitted
Deduplicate texture maps when joining
Summary: If you join several meshes which have TexturesUV textures using join_meshes_as_scene then we amalgamate all the texture images in to a single one. This now checks if some of the images are equal (i.e. the tensors are the same tensor, in the `is` sense; they have the same `id` in Python) and only uses one copy if they are. I have an example of a massive scene made of several textured meshes with some shared, where this makes the difference between fitting the data on the GPU and not. Reviewed By: theschnitz Differential Revision: D25982364 fbshipit-source-id: a8228805f38475c796302e27328a340d9b56c8ef
1 parent cd5af25 commit e12a081

File tree

4 files changed

+155
-42
lines changed

4 files changed

+155
-42
lines changed

pytorch3d/renderer/mesh/textures.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pytorch3d.structures.utils import list_to_packed, list_to_padded, padded_to_list
1111
from torch.nn.functional import interpolate
1212

13-
from .utils import pack_rectangles
13+
from .utils import PackedRectangle, Rectangle, pack_unique_rectangles
1414

1515

1616
# This file contains classes and helper functions for texturing.
@@ -1028,14 +1028,13 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
10281028
maps_list = []
10291029
faces_uvs_list += self.faces_uvs_list()
10301030
verts_uvs_list += self.verts_uvs_list()
1031-
maps_list += list(self.maps_padded().unbind(0))
1031+
maps_list += self.maps_list()
10321032
num_faces_per_mesh = self._num_faces_per_mesh
10331033
for tex in textures:
10341034
verts_uvs_list += tex.verts_uvs_list()
10351035
faces_uvs_list += tex.faces_uvs_list()
10361036
num_faces_per_mesh += tex._num_faces_per_mesh
1037-
tex_map_list = list(tex.maps_padded().unbind(0))
1038-
maps_list += tex_map_list
1037+
maps_list += tex.maps_list()
10391038

10401039
new_tex = self.__class__(
10411040
maps=maps_list,
@@ -1048,10 +1047,7 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
10481047
return new_tex
10491048

10501049
def _place_map_into_single_map(
1051-
self,
1052-
single_map: torch.Tensor,
1053-
map_: torch.Tensor,
1054-
location: Tuple[int, int, bool], # (x,y) and whether flipped
1050+
self, single_map: torch.Tensor, map_: torch.Tensor, location: PackedRectangle
10551051
) -> None:
10561052
"""
10571053
Copy map into a larger tensor single_map at the destination specified by location.
@@ -1064,11 +1060,11 @@ def _place_map_into_single_map(
10641060
map_: (H, W, 3) source data
10651061
location: where to place map
10661062
"""
1067-
do_flip = location[2]
1063+
do_flip = location.flipped
10681064
source = map_.transpose(0, 1) if do_flip else map_
10691065
border_width = 0 if self.align_corners else 1
1070-
lower_u = location[0] + border_width
1071-
lower_v = location[1] + border_width
1066+
lower_u = location.x + border_width
1067+
lower_v = location.y + border_width
10721068
upper_u = lower_u + source.shape[0]
10731069
upper_v = lower_v + source.shape[1]
10741070
single_map[lower_u:upper_u, lower_v:upper_v] = source
@@ -1102,28 +1098,33 @@ def join_scene(self) -> "TexturesUV":
11021098
If align_corners=False, we need to add an artificial border around
11031099
every map.
11041100
1105-
We use the function `pack_rectangles` to provide a layout for the
1106-
single map. _place_map_into_single_map is used to copy the maps
1107-
into the single map. The merging of verts_uvs and faces_uvs are
1108-
handled locally in this function.
1101+
We use the function `pack_unique_rectangles` to provide a layout for
1102+
the single map. This means that if self was created with a list of maps,
1103+
and to() has not been called, and there were two maps which were exactly
1104+
the same tensor object, then they will become the same data in the unified map.
1105+
_place_map_into_single_map is used to copy the maps into the single map.
1106+
The merging of verts_uvs and faces_uvs is handled locally in this function.
11091107
"""
11101108
maps = self.maps_list()
11111109
heights_and_widths = []
11121110
extra_border = 0 if self.align_corners else 2
11131111
for map_ in maps:
11141112
heights_and_widths.append(
1115-
(map_.shape[0] + extra_border, map_.shape[1] + extra_border)
1113+
Rectangle(
1114+
map_.shape[0] + extra_border, map_.shape[1] + extra_border, id(map_)
1115+
)
11161116
)
1117-
merging_plan = pack_rectangles(heights_and_widths)
1117+
merging_plan = pack_unique_rectangles(heights_and_widths)
11181118
# pyre-fixme[16]: `Tensor` has no attribute `new_zeros`.
11191119
single_map = maps[0].new_zeros((*merging_plan.total_size, 3))
11201120
verts_uvs = self.verts_uvs_list()
11211121
verts_uvs_merged = []
11221122

11231123
for map_, loc, uvs in zip(maps, merging_plan.locations, verts_uvs):
11241124
new_uvs = uvs.clone()
1125-
self._place_map_into_single_map(single_map, map_, loc)
1126-
do_flip = loc[2]
1125+
if loc.is_first:
1126+
self._place_map_into_single_map(single_map, map_, loc)
1127+
do_flip = loc.flipped
11271128
x_shape = map_.shape[1] if do_flip else map_.shape[0]
11281129
y_shape = map_.shape[0] if do_flip else map_.shape[1]
11291130

@@ -1164,9 +1165,9 @@ def join_scene(self) -> "TexturesUV":
11641165
denom_y = merging_plan.total_size[1] - one_if_align
11651166
scale_y = y_shape - one_if_align
11661167
new_uvs[:, 1] *= scale_x / denom_x
1167-
new_uvs[:, 1] += (loc[0] + one_if_not_align) / denom_x
1168+
new_uvs[:, 1] += (loc.x + one_if_not_align) / denom_x
11681169
new_uvs[:, 0] *= scale_y / denom_y
1169-
new_uvs[:, 0] += (loc[1] + one_if_not_align) / denom_y
1170+
new_uvs[:, 0] += (loc.y + one_if_not_align) / denom_y
11701171

11711172
verts_uvs_merged.append(new_uvs)
11721173

pytorch3d/renderer/mesh/utils.py

Lines changed: 84 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,25 @@ def _interpolate_zbuf(
6464

6565
# ----------- Rectangle Packing -------------------- #
6666

67+
68+
class Rectangle(NamedTuple):
69+
xsize: int
70+
ysize: int
71+
identifier: int
72+
73+
74+
class PackedRectangle(NamedTuple):
75+
x: int
76+
y: int
77+
flipped: bool
78+
is_first: bool
79+
80+
81+
class PackedRectangles(NamedTuple):
82+
total_size: Tuple[int, int]
83+
locations: List[PackedRectangle]
84+
85+
6786
# Note the order of members matters here because it determines the queue order.
6887
# We want to place longer rectangles first.
6988
class _UnplacedRectangle(NamedTuple):
@@ -74,7 +93,7 @@ class _UnplacedRectangle(NamedTuple):
7493

7594
def _try_place_rectangle(
7695
rect: _UnplacedRectangle,
77-
placed_so_far: List[Tuple[int, int, bool]],
96+
placed_so_far: List[PackedRectangle],
7897
occupied: List[Tuple[int, int]],
7998
) -> bool:
8099
"""
@@ -156,10 +175,11 @@ def _try_place_rectangle(
156175
current_start_idx = idx
157176
if currently_packed >= needed_height:
158177
current_max_width = max(interval[0], current_max_width)
159-
placed_so_far[rect.ind] = (
178+
placed_so_far[rect.ind] = PackedRectangle(
160179
current_max_width,
161180
occupied[current_start_idx - 1][1],
162181
rect.flipped,
182+
True,
163183
)
164184
new_occupied = (
165185
current_max_width + rect.size[0],
@@ -182,11 +202,6 @@ def _try_place_rectangle(
182202
return False
183203

184204

185-
class PackedRectangles(NamedTuple):
186-
total_size: Tuple[int, int]
187-
locations: List[Tuple[int, int, bool]] # (x,y) and whether flipped
188-
189-
190205
def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
191206
"""
192207
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
@@ -200,7 +215,9 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
200215
201216
Returns:
202217
total_size: size of total large rectangle
203-
rectangles: location for each of the input rectangles
218+
rectangles: location for each of the input rectangles.
219+
This includes whether they are flipped.
220+
The is_first field is always True.
204221
"""
205222

206223
if len(sizes) < 2:
@@ -213,14 +230,14 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
213230
else:
214231
queue.append(_UnplacedRectangle((size[0], size[1]), i, False))
215232
queue.sort()
216-
placed_so_far = [(-1, -1, False)] * len(sizes)
233+
placed_so_far = [PackedRectangle(-1, -1, False, False)] * len(sizes)
217234

218235
biggest = queue.pop()
219236
total_width, current_height = biggest.size
220-
placed_so_far[biggest.ind] = (0, 0, biggest.flipped)
237+
placed_so_far[biggest.ind] = PackedRectangle(0, 0, biggest.flipped, True)
221238

222239
second = queue.pop()
223-
placed_so_far[second.ind] = (0, current_height, second.flipped)
240+
placed_so_far[second.ind] = PackedRectangle(0, current_height, second.flipped, True)
224241
current_height += second.size[1]
225242
occupied = [biggest.size, (second.size[0], current_height)]
226243

@@ -236,8 +253,63 @@ def pack_rectangles(sizes: List[Tuple[int, int]]) -> PackedRectangles:
236253

237254
# rect wasn't placed in the current bounding box,
238255
# so we add extra space to fit it in.
239-
placed_so_far[rect.ind] = (0, current_height, rect.flipped)
256+
placed_so_far[rect.ind] = PackedRectangle(0, current_height, rect.flipped, True)
240257
current_height += rect.size[1]
241258
occupied.append((rect.size[0], current_height))
242259

243260
return PackedRectangles((total_width, current_height), placed_so_far)
261+
262+
263+
def pack_unique_rectangles(rectangles: List[Rectangle]) -> PackedRectangles:
264+
"""
265+
Naive rectangle packing in to a large rectangle. Flipping (i.e. rotating
266+
a rectangle by 90 degrees) is allowed. Inputs are deduplicated by their
267+
identifier.
268+
269+
This is a wrapper around pack_rectangles, where inputs come with an
270+
identifier. In particular, it calls pack_rectangles for the deduplicated inputs,
271+
then returns the values for all the inputs. The output for all rectangles with
272+
the same identifier will be the same, except that only the first one will have
273+
the is_first field True.
274+
275+
This is used to join several uv maps into a single scene, see
276+
TexturesUV.join_scene.
277+
278+
Args:
279+
rectangles: List of sizes of rectangles to pack
280+
281+
Returns:
282+
total_size: size of total large rectangle
283+
rectangles: location for each of the input rectangles.
284+
This includes whether they are flipped.
285+
The is_first field is true for the first rectangle
286+
with each identifier.
287+
"""
288+
289+
if len(rectangles) < 2:
290+
raise ValueError("Cannot pack less than two boxes")
291+
292+
input_map = {}
293+
input_indices: List[Tuple[int, bool]] = []
294+
unique_input_sizes: List[Tuple[int, int]] = []
295+
for rectangle in rectangles:
296+
if rectangle.identifier not in input_map:
297+
unique_index = len(unique_input_sizes)
298+
unique_input_sizes.append((rectangle.xsize, rectangle.ysize))
299+
input_map[rectangle.identifier] = unique_index
300+
input_indices.append((unique_index, True))
301+
else:
302+
unique_index = input_map[rectangle.identifier]
303+
input_indices.append((unique_index, False))
304+
305+
if len(unique_input_sizes) == 1:
306+
first = [PackedRectangle(0, 0, False, True)]
307+
rest = (len(rectangles) - 1) * [PackedRectangle(0, 0, False, False)]
308+
return PackedRectangles(unique_input_sizes[0], first + rest)
309+
310+
total_size, unique_locations = pack_rectangles(unique_input_sizes)
311+
full_locations = []
312+
for input_index, first in input_indices:
313+
full_locations.append(unique_locations[input_index]._replace(is_first=first))
314+
315+
return PackedRectangles(total_size, full_locations)

tests/test_render_meshes.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,9 @@ def test_join_uvs(self):
652652
verts_shifted2 = verts.clone()
653653
verts_shifted2 *= 0.5
654654
verts_shifted2[:, 1] -= 7
655+
verts_shifted3 = verts.clone()
656+
verts_shifted3 *= 0.5
657+
verts_shifted3[:, 1] -= 700
655658

656659
[faces] = plain_torus.faces_list()
657660
nocolor = torch.zeros((100, 100), device=device)
@@ -697,7 +700,11 @@ def test_join_uvs(self):
697700
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
698701
mesh2 = Meshes(verts=[verts_shifted1], faces=[faces], textures=textures2)
699702
mesh3 = Meshes(verts=[verts_shifted2], faces=[faces], textures=textures3)
700-
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3])
703+
# mesh4 is like mesh1 but outside the field of view. It is here to test
704+
# that having another texture with the same map doesn't produce
705+
# two copies in the joined map.
706+
mesh4 = Meshes(verts=[verts_shifted3], faces=[faces], textures=textures1)
707+
mesh = join_meshes_as_scene([mesh1, mesh2, mesh3, mesh4])
701708

702709
output = renderer(mesh)[0, ..., :3].cpu()
703710
output1 = renderer(mesh1)[0, ..., :3].cpu()

tests/test_texturing.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
TexturesUV,
1313
TexturesVertex,
1414
_list_to_padded_wrapper,
15+
)
16+
from pytorch3d.renderer.mesh.utils import (
17+
Rectangle,
1518
pack_rectangles,
19+
pack_unique_rectangles,
1620
)
1721
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
1822
from test_meshes import init_mesh
@@ -873,21 +877,24 @@ def wrap_pack(self, sizes):
873877
mask = torch.zeros(total, dtype=torch.bool)
874878
seen_x_bound = False
875879
seen_y_bound = False
876-
for (in_x, in_y), loc in zip(sizes, res.locations):
877-
self.assertGreaterEqual(loc[0], 0)
878-
self.assertGreaterEqual(loc[1], 0)
879-
placed_x, placed_y = (in_y, in_x) if loc[2] else (in_x, in_y)
880-
upper_x = placed_x + loc[0]
881-
upper_y = placed_y + loc[1]
880+
for (in_x, in_y), (out_x, out_y, flipped, is_first) in zip(
881+
sizes, res.locations
882+
):
883+
self.assertTrue(is_first)
884+
self.assertGreaterEqual(out_x, 0)
885+
self.assertGreaterEqual(out_y, 0)
886+
placed_x, placed_y = (in_y, in_x) if flipped else (in_x, in_y)
887+
upper_x = placed_x + out_x
888+
upper_y = placed_y + out_y
882889
self.assertGreaterEqual(total[0], upper_x)
883890
if total[0] == upper_x:
884891
seen_x_bound = True
885892
self.assertGreaterEqual(total[1], upper_y)
886893
if total[1] == upper_y:
887894
seen_y_bound = True
888-
already_taken = torch.sum(mask[loc[0] : upper_x, loc[1] : upper_y])
895+
already_taken = torch.sum(mask[out_x:upper_x, out_y:upper_y])
889896
self.assertEqual(already_taken, 0)
890-
mask[loc[0] : upper_x, loc[1] : upper_y] = 1
897+
mask[out_x:upper_x, out_y:upper_y] = 1
891898
self.assertTrue(seen_x_bound)
892899
self.assertTrue(seen_y_bound)
893900

@@ -930,3 +937,29 @@ def test_random(self):
930937
for j in range(vals.shape[0]):
931938
sizes.append((int(vals[j, 0]), int(vals[j, 1])))
932939
self.wrap_pack(sizes)
940+
941+
def test_all_identical(self):
942+
sizes = [Rectangle(xsize=61, ysize=82, identifier=1729)] * 3
943+
total_size, locations = pack_unique_rectangles(sizes)
944+
self.assertEqual(total_size, (61, 82))
945+
self.assertEqual(len(locations), 3)
946+
for i, (x, y, is_flipped, is_first) in enumerate(locations):
947+
self.assertEqual(x, 0)
948+
self.assertEqual(y, 0)
949+
self.assertFalse(is_flipped)
950+
self.assertEqual(is_first, i == 0)
951+
952+
def test_one_different_id(self):
953+
sizes = [Rectangle(xsize=61, ysize=82, identifier=220)] * 3
954+
sizes.extend([Rectangle(xsize=61, ysize=82, identifier=284)] * 3)
955+
total_size, locations = pack_unique_rectangles(sizes)
956+
self.assertEqual(total_size, (82, 122))
957+
self.assertEqual(len(locations), 6)
958+
for i, (x, y, is_flipped, is_first) in enumerate(locations):
959+
self.assertTrue(is_flipped)
960+
self.assertEqual(is_first, i % 3 == 0)
961+
self.assertEqual(x, 0)
962+
if i < 3:
963+
self.assertEqual(y, 61)
964+
else:
965+
self.assertEqual(y, 0)

0 commit comments

Comments
 (0)