@@ -162,13 +162,15 @@ def __init__(
162
162
raise ValueError (
163
163
'"matrix" has to be a tensor of shape (minibatch, 4, 4)'
164
164
)
165
- # set the device from matrix
165
+ # set dtype and device from matrix
166
+ dtype = matrix .dtype
166
167
device = matrix .device
167
168
self ._matrix = matrix .view (- 1 , 4 , 4 )
168
169
169
170
self ._transforms = [] # store transforms to compose
170
171
self ._lu = None
171
172
self .device = make_device (device )
173
+ self .dtype = dtype
172
174
173
175
def __len__ (self ):
174
176
return self .get_matrix ().shape [0 ]
@@ -200,7 +202,7 @@ def compose(self, *others):
200
202
Returns:
201
203
A new Transform3d with the stored transforms
202
204
"""
203
- out = Transform3d (device = self .device )
205
+ out = Transform3d (dtype = self . dtype , device = self .device )
204
206
out ._matrix = self ._matrix .clone ()
205
207
for other in others :
206
208
if not isinstance (other , Transform3d ):
@@ -259,7 +261,7 @@ def inverse(self, invert_composed: bool = False):
259
261
transformation.
260
262
"""
261
263
262
- tinv = Transform3d (device = self .device )
264
+ tinv = Transform3d (dtype = self . dtype , device = self .device )
263
265
264
266
if invert_composed :
265
267
# first compose then invert
@@ -278,7 +280,7 @@ def inverse(self, invert_composed: bool = False):
278
280
# right-multiplies by the inverse of self._matrix
279
281
# at the end of the composition.
280
282
tinv ._transforms = [t .inverse () for t in reversed (self ._transforms )]
281
- last = Transform3d (device = self .device )
283
+ last = Transform3d (dtype = self . dtype , device = self .device )
282
284
last ._matrix = i_matrix
283
285
tinv ._transforms .append (last )
284
286
else :
@@ -291,7 +293,7 @@ def inverse(self, invert_composed: bool = False):
291
293
def stack (self , * others ):
292
294
transforms = [self ] + list (others )
293
295
matrix = torch .cat ([t ._matrix for t in transforms ], dim = 0 )
294
- out = Transform3d ()
296
+ out = Transform3d (dtype = self . dtype , device = self . device )
295
297
out ._matrix = matrix
296
298
return out
297
299
@@ -392,7 +394,7 @@ def clone(self):
392
394
Returns:
393
395
new Transforms object.
394
396
"""
395
- other = Transform3d (device = self .device )
397
+ other = Transform3d (dtype = self . dtype , device = self .device )
396
398
if self ._lu is not None :
397
399
other ._lu = [elem .clone () for elem in self ._lu ]
398
400
other ._matrix = self ._matrix .clone ()
@@ -422,17 +424,22 @@ def to(
422
424
Transform3d object.
423
425
"""
424
426
device_ = make_device (device )
425
- if not copy and self .device == device_ :
427
+ dtype_ = self .dtype if dtype is None else dtype
428
+ skip_to = self .device == device_ and self .dtype == dtype_
429
+
430
+ if not copy and skip_to :
426
431
return self
427
432
428
433
other = self .clone ()
429
- if self .device == device_ :
434
+
435
+ if skip_to :
430
436
return other
431
437
432
438
other .device = device_
433
- other ._matrix = self ._matrix .to (device = device_ , dtype = dtype )
439
+ other .dtype = dtype_
440
+ other ._matrix = other ._matrix .to (device = device_ , dtype = dtype_ )
434
441
other ._transforms = [
435
- t .to (device_ , copy = copy , dtype = dtype ) for t in other ._transforms
442
+ t .to (device_ , copy = copy , dtype = dtype_ ) for t in other ._transforms
436
443
]
437
444
return other
438
445
0 commit comments