Skip to content

Commit aa4cc0a

Browse files
bottlerfacebook-github-bot
authored andcommitted
images for debugging TexturesUV
Summary: New methods to directly plot a TexturesUV map with its used points, using PIL and matplotlib. Reviewed By: gkioxari Differential Revision: D23782968 fbshipit-source-id: 692970857b5be13a35a3175dc82ac03963a73555
1 parent b149bbf commit aa4cc0a

File tree

6 files changed

+192
-4
lines changed

6 files changed

+192
-4
lines changed

docs/tutorials/render_textured_meshes.ipynb

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
"\n",
9494
"# Data structures and functions for rendering\n",
9595
"from pytorch3d.structures import Meshes\n",
96-
"from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene\n",
96+
"from pytorch3d.vis import AxisArgs, plot_batch_individually, plot_scene, texturesuv_image_matplotlib\n",
9797
"from pytorch3d.renderer import (\n",
9898
" look_at_view_transform,\n",
9999
" FoVPerspectiveCameras, \n",
@@ -236,8 +236,7 @@
236236
"obj_filename = os.path.join(DATA_DIR, \"cow_mesh/cow.obj\")\n",
237237
"\n",
238238
"# Load obj file\n",
239-
"mesh = load_objs_as_meshes([obj_filename], device=device)\n",
240-
"texture_image=mesh.textures.maps_padded()"
239+
"mesh = load_objs_as_meshes([obj_filename], device=device)"
241240
]
242241
},
243242
{
@@ -265,9 +264,29 @@
265264
"outputs": [],
266265
"source": [
267266
"plt.figure(figsize=(7,7))\n",
267+
"texture_image=mesh.textures.maps_padded()\n",
268268
"plt.imshow(texture_image.squeeze().cpu().numpy())\n",
269269
"plt.grid(\"off\");\n",
270-
"plt.axis('off');"
270+
"plt.axis(\"off\");"
271+
]
272+
},
273+
{
274+
"cell_type": "markdown",
275+
"metadata": {},
276+
"source": [
277+
"PyTorch3D has a built-in way to view the texture map with matplotlib along with the points on the map corresponding to vertices. There is also a method, texturesuv_image_PIL, to get a similar image which can be saved to a file."
278+
]
279+
},
280+
{
281+
"cell_type": "code",
282+
"execution_count": null,
283+
"metadata": {},
284+
"outputs": [],
285+
"source": [
286+
"plt.figure(figsize=(7,7))\n",
287+
"texturesuv_image_matplotlib(mesh.textures, subsample=None)\n",
288+
"plt.grid(\"off\");\n",
289+
"plt.axis(\"off\");"
271290
]
272291
},
273292
{

pytorch3d/renderer/mesh/textures.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,42 @@ def join_scene(self) -> "TexturesUV":
11741174
padding_mode=self.padding_mode,
11751175
)
11761176

1177+
def centers_for_image(self, index):
1178+
"""
1179+
Return the locations in the texture map which correspond to the given
1180+
verts_uvs, for one of the meshes. This is potentially useful for
1181+
visualizing the data. See the texturesuv_image_matplotlib and
1182+
texturesuv_image_PIL functions.
1183+
1184+
Args:
1185+
index: batch index of the mesh whose centers to return.
1186+
1187+
Returns:
1188+
centers: coordinates of points in the texture image
1189+
- a FloatTensor of shape (V,2)
1190+
"""
1191+
if self._N != 1:
1192+
raise ValueError(
1193+
"This function only supports plotting textures for one mesh."
1194+
)
1195+
texture_image = self.maps_padded()
1196+
verts_uvs = self.verts_uvs_list()[index][None]
1197+
_, H, W, _3 = texture_image.shape
1198+
coord1 = torch.arange(W).expand(H, W)
1199+
coord2 = torch.arange(H)[:, None].expand(H, W)
1200+
coords = torch.stack([coord1, coord2])[None]
1201+
with torch.no_grad():
1202+
# Get xy cartesian coordinates based on the uv coordinates
1203+
centers = F.grid_sample(
1204+
torch.flip(coords.to(texture_image), [2]),
1205+
# Convert from [0, 1] -> [-1, 1] range expected by grid sample
1206+
verts_uvs[:, None] * 2.0 - 1,
1207+
align_corners=self.align_corners,
1208+
padding_mode=self.padding_mode,
1209+
).cpu()
1210+
centers = centers[0, :, 0].T
1211+
return centers
1212+
11771213

11781214
class TexturesVertex(TexturesBase):
11791215
def __init__(

pytorch3d/vis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33
from .plotly_vis import AxisArgs, Lighting, plot_batch_individually, plot_scene
4+
from .texture_vis import texturesuv_image_matplotlib, texturesuv_image_PIL
45

56

67
__all__ = [k for k in globals().keys() if not k.startswith("_")]

pytorch3d/vis/texture_vis.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
from typing import Optional
3+
4+
import numpy as np
5+
from PIL import Image, ImageDraw
6+
from pytorch3d.renderer.mesh import TexturesUV
7+
8+
9+
def texturesuv_image_matplotlib(
10+
texture: TexturesUV,
11+
*,
12+
texture_index: int = 0,
13+
radius: float = 1,
14+
color=(1.0, 0.0, 0.0),
15+
subsample: Optional[int] = 10000,
16+
origin: str = "upper",
17+
):
18+
"""
19+
Plot the texture image for one element of a TexturesUV with
20+
matplotlib together with verts_uvs positions circled.
21+
In particular a value in verts_uvs which is never referenced
22+
in faces_uvs will still be plotted.
23+
This is for debugging purposes, e.g. to align the map with
24+
the uv coordinates. In particular, matplotlib
25+
is used which is not an official dependency of PyTorch3D.
26+
27+
Args:
28+
texture: a TexturesUV object with one mesh
29+
texture_index: index in the batch to plot
30+
radius: plotted circle radius in pixels
31+
color: any matplotlib-understood color for the circles.
32+
subsample: if not None, number of points to plot.
33+
Otherwise all points are plotted.
34+
origin: "upper" or "lower" like matplotlib.imshow
35+
"""
36+
37+
import matplotlib.pyplot as plt
38+
from matplotlib.patches import Circle
39+
40+
texture_image = texture.maps_padded()
41+
centers = texture.centers_for_image(index=texture_index).numpy()
42+
43+
ax = plt.gca()
44+
ax.imshow(texture_image[texture_index].detach().cpu().numpy(), origin=origin)
45+
46+
n_points = centers.shape[0]
47+
if subsample is None or n_points <= subsample:
48+
indices = range(n_points)
49+
else:
50+
indices = np.random.choice(n_points, subsample, replace=False)
51+
for i in indices:
52+
# setting clip_on=False makes it obvious when
53+
# we have UV coordinates outside the correct range
54+
ax.add_patch(Circle(centers[i], radius, color=color, clip_on=False))
55+
56+
57+
def texturesuv_image_PIL(
58+
texture: TexturesUV,
59+
*,
60+
texture_index: int = 0,
61+
radius: float = 1,
62+
color="red",
63+
subsample: Optional[int] = 10000,
64+
):
65+
"""
66+
Return a PIL image of the texture image of one element of the batch
67+
from a TexturesUV, together with the verts_uvs positions circled.
68+
In particular a value in verts_uvs which is never referenced
69+
in faces_uvs will still be plotted.
70+
This is for debugging purposes, e.g. to align the map with
71+
the uv coordinates. In particular, matplotlib
72+
is used which is not an official dependency of PyTorch3D.
73+
74+
Args:
75+
texture: a TexturesUV object with one mesh
76+
texture_index: index in the batch to plot
77+
radius: plotted circle radius in pixels
78+
color: any PIL-understood color for the circles.
79+
subsample: if not None, number of points to plot.
80+
Otherwise all points are plotted.
81+
82+
Returns:
83+
PIL Image object.
84+
"""
85+
86+
centers = texture.centers_for_image(index=texture_index).numpy()
87+
texture_image = texture.maps_padded()
88+
texture_array = (texture_image[texture_index] * 255).cpu().numpy().astype(np.uint8)
89+
90+
image = Image.fromarray(texture_array)
91+
draw = ImageDraw.Draw(image)
92+
93+
n_points = centers.shape[0]
94+
if subsample is None or n_points <= subsample:
95+
indices = range(n_points)
96+
else:
97+
indices = np.random.choice(n_points, subsample, replace=False)
98+
99+
for i in indices:
100+
x = centers[i][0]
101+
y = centers[i][1]
102+
draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], fill=color)
103+
104+
return image

tests/data/texturesuv_debug.png

94.7 KB
Loading

tests/test_texturing.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33

44
import unittest
5+
from pathlib import Path
56

7+
import numpy as np
68
import torch
79
import torch.nn.functional as F
810
from common_testing import TestCaseMixin
11+
from PIL import Image
912
from pytorch3d.renderer.mesh.rasterizer import Fragments
1013
from pytorch3d.renderer.mesh.textures import (
1114
TexturesAtlas,
@@ -15,9 +18,14 @@
1518
pack_rectangles,
1619
)
1720
from pytorch3d.structures import Meshes, list_to_packed, packed_to_list
21+
from pytorch3d.vis import texturesuv_image_PIL
1822
from test_meshes import TestMeshes
1923

2024

25+
DEBUG = False
26+
DATA_DIR = Path(__file__).resolve().parent / "data"
27+
28+
2129
def tryindex(self, index, tex, meshes, source):
2230
tex2 = tex[index]
2331
meshes2 = meshes[index]
@@ -471,6 +479,10 @@ def test_getitem(self):
471479

472480

473481
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
482+
def setUp(self) -> None:
483+
super().setUp()
484+
torch.manual_seed(42)
485+
474486
def test_sample_textures_uv(self):
475487
barycentric_coords = torch.tensor(
476488
[[0.5, 0.3, 0.2], [0.3, 0.6, 0.1]], dtype=torch.float32
@@ -821,6 +833,22 @@ def test_getitem(self):
821833
tryindex(self, index, tex, meshes, source)
822834
tryindex(self, [2, 4], tex, meshes, source)
823835

836+
def test_png_debug(self):
837+
maps = torch.rand(size=(1, 256, 128, 3)) * torch.tensor([0.8, 1, 0.8])
838+
verts_uvs = torch.rand(size=(1, 20, 2))
839+
faces_uvs = torch.zeros(size=(1, 0, 3), dtype=torch.int64)
840+
tex = TexturesUV(maps=maps, faces_uvs=faces_uvs, verts_uvs=verts_uvs)
841+
842+
image = texturesuv_image_PIL(tex, radius=3)
843+
image_out = np.array(image)
844+
if DEBUG:
845+
image.save(DATA_DIR / "texturesuv_debug_.png")
846+
847+
with Image.open(DATA_DIR / "texturesuv_debug.png") as image_ref_file:
848+
image_ref = np.array(image_ref_file)
849+
850+
self.assertClose(image_out, image_ref)
851+
824852

825853
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
826854
def setUp(self) -> None:

0 commit comments

Comments
 (0)