Skip to content

Commit 79286e7

Browse files
committed
Moving funcs to_numpy and trt_dtype_to_torch_dtype from converter_util to operator
1 parent 979ab42 commit 79286e7

File tree

2 files changed

+34
-34
lines changed

2 files changed

+34
-34
lines changed

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -120,30 +120,6 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int):
120120
return extend_attr_to_tuple(val, size)
121121

122122

123-
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
124-
"""
125-
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
126-
quantized it will be dequantized first.
127-
128-
Args:
129-
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
130-
131-
Returns:
132-
A Numpy array.
133-
"""
134-
135-
if tensor is None:
136-
return tensor
137-
138-
assert isinstance(
139-
tensor, torch.Tensor
140-
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
141-
if tensor.is_quantized:
142-
tensor = tensor.dequantize()
143-
144-
return tensor.cpu().detach().contiguous().numpy()
145-
146-
147123
def has_dynamic_shape(shape: Shape) -> bool:
148124
"""
149125
Determine if the given shape has dynamic dim. i.e. if there're -1 in shape.
@@ -567,13 +543,3 @@ def type_cast(
567543
layer_i.set_output_type(0, cast_type)
568544
set_layer_name(layer_i, target, f"{name}_dtype_change")
569545
return layer_i.get_output(0)
570-
571-
572-
def trt_dtype_to_torch_dtype(trt_dtype):
573-
table = {
574-
trt.bool: torch.bool,
575-
trt.int32: torch.int32,
576-
trt.float16: torch.float16,
577-
trt.float32: torch.float32,
578-
}
579-
return table[trt_dtype]

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,40 @@ def trunc_div(
278278
return output
279279

280280

281+
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
282+
"""
283+
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
284+
quantized it will be dequantized first.
285+
286+
Args:
287+
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
288+
289+
Returns:
290+
A Numpy array.
291+
"""
292+
293+
if tensor is None:
294+
return tensor
295+
296+
assert isinstance(
297+
tensor, torch.Tensor
298+
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
299+
if tensor.is_quantized:
300+
tensor = tensor.dequantize()
301+
302+
return tensor.cpu().detach().contiguous().numpy()
303+
304+
305+
def trt_dtype_to_torch_dtype(trt_dtype):
306+
table = {
307+
trt.bool: torch.bool,
308+
trt.int32: torch.int32,
309+
trt.float16: torch.float16,
310+
trt.float32: torch.float32,
311+
}
312+
return table[trt_dtype]
313+
314+
281315
def add_tile(network, target, kwargs, name):
282316
input_t = kwargs["input"]
283317
input_val = get_trt_tensor(network, input_t, f"{name}_input")

0 commit comments

Comments
 (0)