Skip to content

Commit 327bd2b

Browse files
gkioxarifacebook-github-bot
authored andcommitted
extend sample_points_from_meshes with texture
Summary: Enhanced `sample_points_from_meshes` with texture sampling * This new feature is used to return textures corresponding to the sampled points in `sample_points_from_meshes` Reviewed By: nikhilaravi Differential Revision: D24031525 fbshipit-source-id: 8e5d8f784cc38aa391aa8e84e54423bd9fad7ad1
1 parent 5c9485c commit 327bd2b

File tree

2 files changed

+220
-9
lines changed

2 files changed

+220
-9
lines changed

pytorch3d/ops/sample_points_from_meshes.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,19 @@
1111
import torch
1212
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
1313
from pytorch3d.ops.packed_to_padded import packed_to_padded
14+
from pytorch3d.renderer.mesh.rasterizer import Fragments as MeshFragments
1415

1516

1617
def sample_points_from_meshes(
17-
meshes, num_samples: int = 10000, return_normals: bool = False
18-
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
18+
meshes,
19+
num_samples: int = 10000,
20+
return_normals: bool = False,
21+
return_textures: bool = False,
22+
) -> Union[
23+
torch.Tensor,
24+
Tuple[torch.Tensor, torch.Tensor],
25+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
26+
]:
1927
"""
2028
Convert a batch of meshes to a pointcloud by uniformly sampling points on
2129
the surface of the mesh with probability proportional to the face area.
@@ -24,10 +32,10 @@ def sample_points_from_meshes(
2432
meshes: A Meshes object with a batch of N meshes.
2533
num_samples: Integer giving the number of point samples per mesh.
2634
return_normals: If True, return normals for the sampled points.
27-
eps: (float) used to clamp the norm of the normals to avoid dividing by 0.
35+
return_textures: If True, return textures for the sampled points.
2836
2937
Returns:
30-
2-element tuple containing
38+
3-element tuple containing
3139
3240
- **samples**: FloatTensor of shape (N, num_samples, 3) giving the
3341
coordinates of sampled points for each mesh in the batch. For empty
@@ -36,13 +44,28 @@ def sample_points_from_meshes(
3644
to each sampled point. Only returned if return_normals is True.
3745
For empty meshes the corresponding row in the normals array will
3846
be filled with 0.
47+
- **textures**: FloatTensor of shape (N, num_samples, C) giving a C-dimensional
48+
texture vector to each sampled point. Only returned if return_textures is True.
49+
For empty meshes the corresponding row in the textures array will
50+
be filled with 0.
51+
52+
Note that in a future releases, we will replace the 3-element tuple output
53+
with a `Pointclouds` datastructure, as follows
54+
55+
.. code-block:: python
56+
57+
Poinclouds(samples, normals=normals, features=textures)
3958
"""
4059
if meshes.isempty():
4160
raise ValueError("Meshes are empty.")
4261

4362
verts = meshes.verts_packed()
4463
if not torch.isfinite(verts).all():
4564
raise ValueError("Meshes contain nan or inf.")
65+
66+
if return_textures and meshes.textures is None:
67+
raise ValueError("Meshes do not contain textures.")
68+
4669
faces = meshes.faces_packed()
4770
mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
4871
num_meshes = len(meshes)
@@ -66,7 +89,7 @@ def sample_points_from_meshes(
6689
sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
6790

6891
# Get the vertex coordinates of the sampled faces.
69-
face_verts = verts[faces.long()]
92+
face_verts = verts[faces]
7093
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
7194

7295
# Randomly generate barycentric coords.
@@ -92,9 +115,29 @@ def sample_points_from_meshes(
92115
vert_normals = vert_normals[sample_face_idxs]
93116
normals[meshes.valid] = vert_normals
94117

118+
if return_textures:
119+
# fragment data are of shape NxHxWxK. Here H=S, W=1 & K=1.
120+
pix_to_face = sample_face_idxs.view(len(meshes), num_samples, 1, 1) # NxSx1x1
121+
bary = torch.stack((w0, w1, w2), dim=2).unsqueeze(2).unsqueeze(2) # NxSx1x1x3
122+
# zbuf and dists are not used in `sample_textures` so we initialize them with dummy
123+
dummy = torch.zeros(
124+
(len(meshes), num_samples, 1, 1), device=meshes.device, dtype=torch.float32
125+
) # NxSx1x1
126+
fragments = MeshFragments(
127+
pix_to_face=pix_to_face, zbuf=dummy, bary_coords=bary, dists=dummy
128+
)
129+
textures = meshes.sample_textures(fragments) # NxSx1x1xC
130+
textures = textures[:, :, 0, 0, :] # NxSxC
131+
132+
# return
133+
# TODO(gkioxari) consider returning a Pointclouds instance [breaking]
134+
if return_normals and return_textures:
135+
return samples, normals, textures
136+
if return_normals: # return_textures is False
95137
return samples, normals
96-
else:
97-
return samples
138+
if return_textures: # return_normals is False
139+
return samples, textures
140+
return samples
98141

99142

100143
def _rand_barycentric_coords(

tests/test_sample_points_from_meshes.py

Lines changed: 170 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,31 @@
44
import unittest
55
from pathlib import Path
66

7+
import numpy as np
78
import torch
89
from common_testing import TestCaseMixin, get_random_cuda_device
10+
from PIL import Image
11+
from pytorch3d.io import load_objs_as_meshes
912
from pytorch3d.ops import sample_points_from_meshes
10-
from pytorch3d.structures.meshes import Meshes
13+
from pytorch3d.renderer import TexturesVertex
14+
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
15+
from pytorch3d.renderer.mesh.rasterize_meshes import barycentric_coordinates
16+
from pytorch3d.renderer.points import (
17+
NormWeightedCompositor,
18+
PointsRasterizationSettings,
19+
PointsRasterizer,
20+
PointsRenderer,
21+
)
22+
from pytorch3d.structures import Meshes, Pointclouds
1123
from pytorch3d.utils.ico_sphere import ico_sphere
1224

1325

26+
# If DEBUG=True, save out images generated in the tests for debugging.
27+
# All saved images have prefix DEBUG_
28+
DEBUG = False
29+
DATA_DIR = Path(__file__).resolve().parent / "data"
30+
31+
1432
class TestSamplePoints(TestCaseMixin, unittest.TestCase):
1533
def setUp(self) -> None:
1634
super().setUp()
@@ -22,18 +40,27 @@ def init_meshes(
2240
num_verts: int = 1000,
2341
num_faces: int = 3000,
2442
device: str = "cpu",
43+
add_texture: bool = False,
2544
):
2645
device = torch.device(device)
2746
verts_list = []
2847
faces_list = []
48+
texts_list = []
2949
for _ in range(num_meshes):
3050
verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
3151
faces = torch.randint(
3252
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
3353
)
54+
texts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
3455
verts_list.append(verts)
3556
faces_list.append(faces)
36-
meshes = Meshes(verts_list, faces_list)
57+
texts_list.append(texts)
58+
59+
# create textures
60+
textures = None
61+
if add_texture:
62+
textures = TexturesVertex(texts_list)
63+
meshes = Meshes(verts=verts_list, faces=faces_list, textures=textures)
3764

3865
return meshes
3966

@@ -264,6 +291,147 @@ def test_verts_nan(self):
264291
meshes, num_samples=100, return_normals=True
265292
)
266293

294+
def test_outputs(self):
295+
296+
for add_texture in (True, False):
297+
meshes = TestSamplePoints.init_meshes(
298+
device=torch.device("cuda:0"), add_texture=add_texture
299+
)
300+
out1 = sample_points_from_meshes(meshes, num_samples=100)
301+
self.assertTrue(torch.is_tensor(out1))
302+
303+
out2 = sample_points_from_meshes(
304+
meshes, num_samples=100, return_normals=True
305+
)
306+
self.assertTrue(isinstance(out2, tuple) and len(out2) == 2)
307+
308+
if add_texture:
309+
out3 = sample_points_from_meshes(
310+
meshes, num_samples=100, return_textures=True
311+
)
312+
self.assertTrue(isinstance(out3, tuple) and len(out3) == 2)
313+
314+
out4 = sample_points_from_meshes(
315+
meshes, num_samples=100, return_normals=True, return_textures=True
316+
)
317+
self.assertTrue(isinstance(out4, tuple) and len(out4) == 3)
318+
else:
319+
with self.assertRaisesRegex(
320+
ValueError, "Meshes do not contain textures."
321+
):
322+
sample_points_from_meshes(
323+
meshes, num_samples=100, return_textures=True
324+
)
325+
326+
with self.assertRaisesRegex(
327+
ValueError, "Meshes do not contain textures."
328+
):
329+
sample_points_from_meshes(
330+
meshes,
331+
num_samples=100,
332+
return_normals=True,
333+
return_textures=True,
334+
)
335+
336+
def test_texture_sampling(self):
337+
device = torch.device("cuda:0")
338+
batch_size = 6
339+
# verts
340+
verts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
341+
verts[:, :3, 2] = 1.0
342+
verts[:, 3:, 2] = -1.0
343+
# textures
344+
texts = torch.rand((batch_size, 6, 3), device=device, dtype=torch.float32)
345+
# faces
346+
faces = torch.tensor([[0, 1, 2], [3, 4, 5]], device=device, dtype=torch.int64)
347+
faces = faces.view(1, 2, 3).expand(batch_size, -1, -1)
348+
349+
meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(texts))
350+
351+
num_samples = 24
352+
samples, normals, textures = sample_points_from_meshes(
353+
meshes, num_samples=num_samples, return_normals=True, return_textures=True
354+
)
355+
356+
textures_naive = torch.zeros(
357+
(batch_size, num_samples, 3), dtype=torch.float32, device=device
358+
)
359+
for n in range(batch_size):
360+
for i in range(num_samples):
361+
p = samples[n, i]
362+
if p[2] > 0.0: # sampled from 1st face
363+
v0, v1, v2 = verts[n, 0, :2], verts[n, 1, :2], verts[n, 2, :2]
364+
w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
365+
t0, t1, t2 = texts[n, 0], texts[n, 1], texts[n, 2]
366+
else: # sampled from 2nd face
367+
v0, v1, v2 = verts[n, 3, :2], verts[n, 4, :2], verts[n, 5, :2]
368+
w0, w1, w2 = barycentric_coordinates(p[:2], v0, v1, v2)
369+
t0, t1, t2 = texts[n, 3], texts[n, 4], texts[n, 5]
370+
371+
tt = w0 * t0 + w1 * t1 + w2 * t2
372+
textures_naive[n, i] = tt
373+
374+
self.assertClose(textures, textures_naive)
375+
376+
def test_texture_sampling_cow(self):
377+
# test texture sampling for the cow example by converting
378+
# the cow mesh and its texture uv to a pointcloud with texture
379+
380+
device = torch.device("cuda:0")
381+
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
382+
obj_filename = obj_dir / "cow_mesh/cow.obj"
383+
384+
for text_type in ("uv", "atlas"):
385+
# Load mesh + texture
386+
if text_type == "uv":
387+
mesh = load_objs_as_meshes(
388+
[obj_filename], device=device, load_textures=True, texture_wrap=None
389+
)
390+
elif text_type == "atlas":
391+
mesh = load_objs_as_meshes(
392+
[obj_filename],
393+
device=device,
394+
load_textures=True,
395+
create_texture_atlas=True,
396+
texture_atlas_size=8,
397+
texture_wrap=None,
398+
)
399+
400+
points, normals, textures = sample_points_from_meshes(
401+
mesh, num_samples=50000, return_normals=True, return_textures=True
402+
)
403+
pointclouds = Pointclouds(points, normals=normals, features=textures)
404+
405+
for pos in ("front", "back"):
406+
# Init rasterizer settings
407+
if pos == "back":
408+
azim = 0.0
409+
elif pos == "front":
410+
azim = 180
411+
R, T = look_at_view_transform(2.7, 0, azim)
412+
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
413+
414+
raster_settings = PointsRasterizationSettings(
415+
image_size=512, radius=1e-2, points_per_pixel=1
416+
)
417+
418+
rasterizer = PointsRasterizer(
419+
cameras=cameras, raster_settings=raster_settings
420+
)
421+
compositor = NormWeightedCompositor()
422+
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
423+
images = renderer(pointclouds)
424+
425+
rgb = images[0, ..., :3].squeeze().cpu()
426+
if DEBUG:
427+
filename = "DEBUG_cow_mesh_to_pointcloud_%s_%s.png" % (
428+
text_type,
429+
pos,
430+
)
431+
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
432+
DATA_DIR / filename
433+
)
434+
267435
@staticmethod
268436
def sample_points_with_init(
269437
num_meshes: int,

0 commit comments

Comments
 (0)