Skip to content

Commit 2bcf5f4

Browse files
committed
Move to_numpy implementation to converter_util
1 parent 79286e7 commit 2bcf5f4

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,3 +543,27 @@ def type_cast(
543543
layer_i.set_output_type(0, cast_type)
544544
set_layer_name(layer_i, target, f"{name}_dtype_change")
545545
return layer_i.get_output(0)
546+
547+
548+
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
549+
"""
550+
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
551+
quantized it will be dequantized first.
552+
553+
Args:
554+
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
555+
556+
Returns:
557+
A Numpy array.
558+
"""
559+
560+
if tensor is None:
561+
return tensor
562+
563+
assert isinstance(
564+
tensor, torch.Tensor
565+
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
566+
if tensor.is_quantized:
567+
tensor = tensor.dequantize()
568+
569+
return tensor.cpu().detach().contiguous().numpy()

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .converter_utils import prepend_ones
2323
from .converter_utils import has_dynamic_shape
2424
from .converter_utils import get_shape_with_dynamic_shape
25+
from .converter_utils import to_numpy
2526

2627
from ..types import (
2728
Shape,
@@ -278,30 +279,6 @@ def trunc_div(
278279
return output
279280

280281

281-
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
282-
"""
283-
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
284-
quantized it will be dequantized first.
285-
286-
Args:
287-
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
288-
289-
Returns:
290-
A Numpy array.
291-
"""
292-
293-
if tensor is None:
294-
return tensor
295-
296-
assert isinstance(
297-
tensor, torch.Tensor
298-
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
299-
if tensor.is_quantized:
300-
tensor = tensor.dequantize()
301-
302-
return tensor.cpu().detach().contiguous().numpy()
303-
304-
305282
def trt_dtype_to_torch_dtype(trt_dtype):
306283
table = {
307284
trt.bool: torch.bool,

0 commit comments

Comments
 (0)