File tree Expand file tree Collapse file tree 2 files changed +25
-24
lines changed
py/torch_tensorrt/fx/converters Expand file tree Collapse file tree 2 files changed +25
-24
lines changed Original file line number Diff line number Diff line change @@ -543,3 +543,27 @@ def type_cast(
543
543
layer_i .set_output_type (0 , cast_type )
544
544
set_layer_name (layer_i , target , f"{ name } _dtype_change" )
545
545
return layer_i .get_output (0 )
546
+
547
+
548
+ def to_numpy (tensor : Optional [torch .Tensor ]) -> Optional [np .ndarray ]:
549
+ """
550
+ Convert a PyTorch Tensor to a Numpy Array. If the tensor is
551
+ quantized it will be dequantized first.
552
+
553
+ Args:
554
+ tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
555
+
556
+ Returns:
557
+ A Numpy array.
558
+ """
559
+
560
+ if tensor is None :
561
+ return tensor
562
+
563
+ assert isinstance (
564
+ tensor , torch .Tensor
565
+ ), f"to_numpy can only be called on None or a torch.Tensor, got: { tensor } "
566
+ if tensor .is_quantized :
567
+ tensor = tensor .dequantize ()
568
+
569
+ return tensor .cpu ().detach ().contiguous ().numpy ()
Original file line number Diff line number Diff line change 22
22
from .converter_utils import prepend_ones
23
23
from .converter_utils import has_dynamic_shape
24
24
from .converter_utils import get_shape_with_dynamic_shape
25
+ from .converter_utils import to_numpy
25
26
26
27
from ..types import (
27
28
Shape ,
@@ -278,30 +279,6 @@ def trunc_div(
278
279
return output
279
280
280
281
281
- def to_numpy (tensor : Optional [torch .Tensor ]) -> Optional [np .ndarray ]:
282
- """
283
- Convert a PyTorch Tensor to a Numpy Array. If the tensor is
284
- quantized it will be dequantized first.
285
-
286
- Args:
287
- tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
288
-
289
- Returns:
290
- A Numpy array.
291
- """
292
-
293
- if tensor is None :
294
- return tensor
295
-
296
- assert isinstance (
297
- tensor , torch .Tensor
298
- ), f"to_numpy can only be called on None or a torch.Tensor, got: { tensor } "
299
- if tensor .is_quantized :
300
- tensor = tensor .dequantize ()
301
-
302
- return tensor .cpu ().detach ().contiguous ().numpy ()
303
-
304
-
305
282
def trt_dtype_to_torch_dtype (trt_dtype ):
306
283
table = {
307
284
trt .bool : torch .bool ,
You can’t perform that action at this time.
0 commit comments