Skip to content

Commit 569e522

Browse files
Randlfacebook-github-bot
authored andcommitted
Add check for verts and faces being on same device and also checks for pointclouds/features/normals being on the same device (#384)
Summary: Pull Request resolved: #384 Test Plan: `test_meshes` and `test_points` Reviewed By: gkioxari Differential Revision: D24730524 Pulled By: nikhilaravi fbshipit-source-id: acbd35be5d9f1b13b4d56f3db14f6e8c2c0f7596
1 parent 1934046 commit 569e522

File tree

4 files changed

+85
-2
lines changed

4 files changed

+85
-2
lines changed

pytorch3d/structures/meshes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,13 @@ def __init__(self, verts=None, faces=None, textures=None):
325325
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
326326
if self._N > 0:
327327
self.device = self._verts_list[0].device
328+
if not (
329+
all(v.device == self.device for v in verts)
330+
and all(f.device == self.device for f in faces)
331+
):
332+
raise ValueError(
333+
"All Verts and Faces tensors should be on same device."
334+
)
328335
self._num_verts_per_mesh = torch.tensor(
329336
[len(v) for v in self._verts_list], device=self.device
330337
)
@@ -341,7 +348,6 @@ def __init__(self, verts=None, faces=None, textures=None):
341348
dtype=torch.bool,
342349
device=self.device,
343350
)
344-
345351
if (len(self._num_verts_per_mesh.unique()) == 1) and (
346352
len(self._num_faces_per_mesh.unique()) == 1
347353
):
@@ -355,6 +361,10 @@ def __init__(self, verts=None, faces=None, textures=None):
355361
self._N = self._verts_padded.shape[0]
356362
self._V = self._verts_padded.shape[1]
357363

364+
if verts.device != faces.device:
365+
msg = "Verts and Faces tensors should be on same device. \n Got {} and {}."
366+
raise ValueError(msg.format(verts.device, faces.device))
367+
358368
self.device = self._verts_padded.device
359369
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
360370
if self._N > 0:

pytorch3d/structures/pointclouds.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,13 @@ def __init__(self, points, normals=None, features=None):
180180
self._num_points_per_cloud = []
181181

182182
if self._N > 0:
183+
self.device = self._points_list[0].device
183184
for p in self._points_list:
184185
if len(p) > 0 and (p.dim() != 2 or p.shape[1] != 3):
185186
raise ValueError("Clouds in list must be of shape Px3 or empty")
187+
if p.device != self.device:
188+
raise ValueError("All points must be on the same device")
186189

187-
self.device = self._points_list[0].device
188190
num_points_per_cloud = torch.tensor(
189191
[len(p) for p in self._points_list], device=self.device
190192
)
@@ -261,6 +263,10 @@ def _parse_auxiliary_input(self, aux_input):
261263
raise ValueError(
262264
"A cloud has mismatched numbers of points and inputs"
263265
)
266+
if d.device != self.device:
267+
raise ValueError(
268+
"All auxillary inputs must be on the same device as the points."
269+
)
264270
if p > 0:
265271
if d.dim() != 2:
266272
raise ValueError(
@@ -283,6 +289,10 @@ def _parse_auxiliary_input(self, aux_input):
283289
"Inputs tensor must have the right maximum \
284290
number of points in each cloud."
285291
)
292+
if aux_input.device != self.device:
293+
raise ValueError(
294+
"All auxillary inputs must be on the same device as the points."
295+
)
286296
aux_input_C = aux_input.shape[2]
287297
return None, aux_input, aux_input_C
288298
else:

tests/test_meshes.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
import random
34
import unittest
45

56
import numpy as np
@@ -162,6 +163,29 @@ def test_simple(self):
162163
torch.tensor([0, 3, 8], dtype=torch.int64),
163164
)
164165

166+
def test_init_error(self):
167+
# Check if correct errors are raised when verts/faces are on
168+
# different devices
169+
170+
mesh = TestMeshes.init_mesh(10, 10, 100)
171+
verts_list = mesh.verts_list() # all tensors on cpu
172+
verts_list = [
173+
v.to("cuda:0") if random.uniform(0, 1) > 0.5 else v for v in verts_list
174+
]
175+
faces_list = mesh.faces_list()
176+
177+
with self.assertRaises(ValueError) as cm:
178+
Meshes(verts=verts_list, faces=faces_list)
179+
self.assertTrue("same device" in cm.msg)
180+
181+
verts_padded = mesh.verts_padded() # on cpu
182+
verts_padded = verts_padded.to("cuda:0")
183+
faces_padded = mesh.faces_padded()
184+
185+
with self.assertRaises(ValueError) as cm:
186+
Meshes(verts=verts_padded, faces=faces_padded)
187+
self.assertTrue("same device" in cm.msg)
188+
165189
def test_simple_random_meshes(self):
166190

167191
# Define the test mesh object either as a list or tensor of faces/verts.

tests/test_pointclouds.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33

4+
import random
45
import unittest
56

67
import numpy as np
@@ -126,6 +127,44 @@ def test_simple(self):
126127
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
127128
)
128129

130+
def test_init_error(self):
131+
# Check if correct errors are raised when verts/faces are on
132+
# different devices
133+
134+
clouds = self.init_cloud(10, 100, 5)
135+
points_list = clouds.points_list() # all tensors on cuda:0
136+
points_list = [
137+
p.to("cpu") if random.uniform(0, 1) > 0.5 else p for p in points_list
138+
]
139+
features_list = clouds.features_list()
140+
normals_list = clouds.normals_list()
141+
142+
with self.assertRaises(ValueError) as cm:
143+
Pointclouds(
144+
points=points_list, features=features_list, normals=normals_list
145+
)
146+
self.assertTrue("same device" in cm.msg)
147+
148+
points_list = clouds.points_list()
149+
features_list = [
150+
f.to("cpu") if random.uniform(0, 1) > 0.2 else f for f in features_list
151+
]
152+
with self.assertRaises(ValueError) as cm:
153+
Pointclouds(
154+
points=points_list, features=features_list, normals=normals_list
155+
)
156+
self.assertTrue("same device" in cm.msg)
157+
158+
points_padded = clouds.points_padded() # on cuda:0
159+
features_padded = clouds.features_padded().to("cpu")
160+
normals_padded = clouds.normals_padded()
161+
162+
with self.assertRaises(ValueError) as cm:
163+
Pointclouds(
164+
points=points_padded, features=features_padded, normals=normals_padded
165+
)
166+
self.assertTrue("same device" in cm.msg)
167+
129168
def test_all_constructions(self):
130169
public_getters = [
131170
"points_list",

0 commit comments

Comments
 (0)