Skip to content

Commit 44508ed

Browse files
patricklabatut authored and facebook-github-bot committed
Make Transform3d.to() not ignore dtype
Summary: Make Transform3d.to() not ignore a different dtype when the device is the same and no copy is requested. Also fix other methods where dtype is ignored.

Reviewed By: nikhilaravi

Differential Revision: D28981171

fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
1 parent 626bf3f commit 44508ed

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

pytorch3d/transforms/transform3d.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,15 @@ def __init__(
162162
raise ValueError(
163163
'"matrix" has to be a tensor of shape (minibatch, 4, 4)'
164164
)
165-
# set the device from matrix
165+
# set dtype and device from matrix
166+
dtype = matrix.dtype
166167
device = matrix.device
167168
self._matrix = matrix.view(-1, 4, 4)
168169

169170
self._transforms = [] # store transforms to compose
170171
self._lu = None
171172
self.device = make_device(device)
173+
self.dtype = dtype
172174

173175
def __len__(self):
174176
return self.get_matrix().shape[0]
@@ -200,7 +202,7 @@ def compose(self, *others):
200202
Returns:
201203
A new Transform3d with the stored transforms
202204
"""
203-
out = Transform3d(device=self.device)
205+
out = Transform3d(dtype=self.dtype, device=self.device)
204206
out._matrix = self._matrix.clone()
205207
for other in others:
206208
if not isinstance(other, Transform3d):
@@ -259,7 +261,7 @@ def inverse(self, invert_composed: bool = False):
259261
transformation.
260262
"""
261263

262-
tinv = Transform3d(device=self.device)
264+
tinv = Transform3d(dtype=self.dtype, device=self.device)
263265

264266
if invert_composed:
265267
# first compose then invert
@@ -278,7 +280,7 @@ def inverse(self, invert_composed: bool = False):
278280
# right-multiplies by the inverse of self._matrix
279281
# at the end of the composition.
280282
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
281-
last = Transform3d(device=self.device)
283+
last = Transform3d(dtype=self.dtype, device=self.device)
282284
last._matrix = i_matrix
283285
tinv._transforms.append(last)
284286
else:
@@ -291,7 +293,7 @@ def inverse(self, invert_composed: bool = False):
291293
def stack(self, *others):
292294
transforms = [self] + list(others)
293295
matrix = torch.cat([t._matrix for t in transforms], dim=0)
294-
out = Transform3d()
296+
out = Transform3d(dtype=self.dtype, device=self.device)
295297
out._matrix = matrix
296298
return out
297299

@@ -392,7 +394,7 @@ def clone(self):
392394
Returns:
393395
new Transforms object.
394396
"""
395-
other = Transform3d(device=self.device)
397+
other = Transform3d(dtype=self.dtype, device=self.device)
396398
if self._lu is not None:
397399
other._lu = [elem.clone() for elem in self._lu]
398400
other._matrix = self._matrix.clone()
@@ -422,17 +424,22 @@ def to(
422424
Transform3d object.
423425
"""
424426
device_ = make_device(device)
425-
if not copy and self.device == device_:
427+
dtype_ = self.dtype if dtype is None else dtype
428+
skip_to = self.device == device_ and self.dtype == dtype_
429+
430+
if not copy and skip_to:
426431
return self
427432

428433
other = self.clone()
429-
if self.device == device_:
434+
435+
if skip_to:
430436
return other
431437

432438
other.device = device_
433-
other._matrix = self._matrix.to(device=device_, dtype=dtype)
439+
other.dtype = dtype_
440+
other._matrix = other._matrix.to(device=device_, dtype=dtype_)
434441
other._transforms = [
435-
t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
442+
t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms
436443
]
437444
return other
438445

tests/test_transforms.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,45 @@ def test_to(self):
2828
cpu_t = t.to("cpu")
2929
self.assertEqual(cpu_device, cpu_t.device)
3030
self.assertEqual(cpu_device, t.device)
31+
self.assertEqual(torch.float32, cpu_t.dtype)
32+
self.assertEqual(torch.float32, t.dtype)
3133
self.assertIs(t, cpu_t)
3234

3335
cpu_t = t.to(cpu_device)
3436
self.assertEqual(cpu_device, cpu_t.device)
3537
self.assertEqual(cpu_device, t.device)
38+
self.assertEqual(torch.float32, cpu_t.dtype)
39+
self.assertEqual(torch.float32, t.dtype)
3640
self.assertIs(t, cpu_t)
3741

42+
cpu_t = t.to(dtype=torch.float64, device=cpu_device)
43+
self.assertEqual(cpu_device, cpu_t.device)
44+
self.assertEqual(cpu_device, t.device)
45+
self.assertEqual(torch.float64, cpu_t.dtype)
46+
self.assertEqual(torch.float32, t.dtype)
47+
self.assertIsNot(t, cpu_t)
48+
3849
cuda_device = torch.device("cuda")
3950

4051
cuda_t = t.to("cuda")
4152
self.assertEqual(cuda_device, cuda_t.device)
4253
self.assertEqual(cpu_device, t.device)
54+
self.assertEqual(torch.float32, cuda_t.dtype)
55+
self.assertEqual(torch.float32, t.dtype)
4356
self.assertIsNot(t, cuda_t)
4457

4558
cuda_t = t.to(cuda_device)
4659
self.assertEqual(cuda_device, cuda_t.device)
4760
self.assertEqual(cpu_device, t.device)
61+
self.assertEqual(torch.float32, cuda_t.dtype)
62+
self.assertEqual(torch.float32, t.dtype)
63+
self.assertIsNot(t, cuda_t)
64+
65+
cuda_t = t.to(dtype=torch.float64, device=cuda_device)
66+
self.assertEqual(cuda_device, cuda_t.device)
67+
self.assertEqual(cpu_device, t.device)
68+
self.assertEqual(torch.float64, cuda_t.dtype)
69+
self.assertEqual(torch.float32, t.dtype)
4870
self.assertIsNot(t, cuda_t)
4971

5072
cpu_points = torch.rand(9, 3)

0 commit comments

Comments (0)