Skip to content

Commit 8dd9a33

Browse files
committed
update type conversion
1 parent 25c1ff4 commit 8dd9a33

File tree

2 files changed

+14
-29
lines changed

2 files changed

+14
-29
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,7 @@ def cast_int_int_div_trt_tensor(
124124
Returns:
125125
A list of lhs_val and rhs_val casted to the approriate datatype
126126
"""
127-
if (lhs_val.dtype == trt.int8 or lhs_val.dtype == trt.int32) and (
128-
rhs_val.dtype == trt.int8 or rhs_val.dtype == trt.int32
129-
):
127+
if lhs_val.dtype == trt.int32 and rhs_val.dtype == trt.int32:
130128
lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name)
131129
rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name)
132130
return [lhs_val, rhs_val]
@@ -190,14 +188,7 @@ def extend_attr_to_tuple(
190188
return val
191189

192190

193-
def trt_cast_int_to_float(network: TRTNetwork, name: str, tensor: TRTTensor):
194-
if tensor.dtype == trt.int8 or tensor.dtype == trt.int32:
195-
return cast_trt_tensor(network, tensor, trt.float32, name)
196-
197-
return tensor
198-
199-
200-
def trt_cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
191+
def cast_int_or_float_to_bool(network: TRTNetwork, name: str, tensor: TRTTensor):
201192
if tensor.dtype != trt.bool:
202193
return cast_trt_tensor(network, tensor, trt.bool, name)
203194

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from torch.fx.node import Target
66
from torch_tensorrt.dynamo._SourceIR import SourceIR
77
from torch_tensorrt.dynamo.conversion.converter_utils import (
8-
trt_cast_int_or_float_to_bool,
9-
trt_cast_int_to_float,
8+
cast_int_int_div_trt_tensor,
9+
cast_int_or_float_to_bool,
1010
)
1111
from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
1212
convert_binary_elementwise,
@@ -324,11 +324,8 @@ def div(
324324
lhs_val: Union[TRTTensor, int, float],
325325
rhs_val: Union[TRTTensor, int, float],
326326
) -> TRTTensor:
327-
if isinstance(lhs_val, TRTTensor):
328-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
329-
330-
if isinstance(rhs_val, TRTTensor):
331-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
327+
if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor):
328+
lhs_val, rhs_val = cast_int_int_div_trt_tensor(network, lhs_val, rhs_val, name)
332329

333330
return convert_binary_elementwise(
334331
network, target, source_ir, name, trt.ElementWiseOperation.DIV, lhs_val, rhs_val
@@ -343,11 +340,8 @@ def pow(
343340
lhs_val: Union[TRTTensor, int, float],
344341
rhs_val: Union[TRTTensor, int, float],
345342
) -> TRTTensor:
346-
if isinstance(lhs_val, TRTTensor):
347-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
348-
349-
if isinstance(rhs_val, TRTTensor):
350-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
343+
if isinstance(lhs_val, TRTTensor) and isinstance(rhs_val, TRTTensor):
344+
lhs_val, rhs_val = cast_int_int_div_trt_tensor(network, lhs_val, rhs_val, name)
351345

352346
return convert_binary_elementwise(
353347
network, target, source_ir, name, trt.ElementWiseOperation.POW, lhs_val, rhs_val
@@ -382,10 +376,10 @@ def logical_and(
382376
rhs_val: Union[TRTTensor, int, float],
383377
) -> TRTTensor:
384378
if isinstance(lhs_val, TRTTensor):
385-
lhs_val = trt_cast_int_or_float_to_bool(network, name, lhs_val)
379+
lhs_val = cast_int_or_float_to_bool(network, name, lhs_val)
386380

387381
if isinstance(rhs_val, TRTTensor):
388-
rhs_val = trt_cast_int_or_float_to_bool(network, name, rhs_val)
382+
rhs_val = cast_int_or_float_to_bool(network, name, rhs_val)
389383

390384
return convert_binary_elementwise(
391385
network, target, source_ir, name, trt.ElementWiseOperation.AND, lhs_val, rhs_val
@@ -401,10 +395,10 @@ def logical_or(
401395
rhs_val: Union[TRTTensor, int, float],
402396
) -> TRTTensor:
403397
if isinstance(lhs_val, TRTTensor):
404-
lhs_val = trt_cast_int_or_float_to_bool(network, name, lhs_val)
398+
lhs_val = cast_int_or_float_to_bool(network, name, lhs_val)
405399

406400
if isinstance(rhs_val, TRTTensor):
407-
rhs_val = trt_cast_int_or_float_to_bool(network, name, rhs_val)
401+
rhs_val = cast_int_or_float_to_bool(network, name, rhs_val)
408402

409403
return convert_binary_elementwise(
410404
network, target, source_ir, name, trt.ElementWiseOperation.OR, lhs_val, rhs_val
@@ -420,10 +414,10 @@ def logical_xor(
420414
rhs_val: Union[TRTTensor, int, float],
421415
) -> TRTTensor:
422416
if isinstance(lhs_val, TRTTensor):
423-
lhs_val = trt_cast_int_or_float_to_bool(network, name, lhs_val)
417+
lhs_val = cast_int_or_float_to_bool(network, name, lhs_val)
424418

425419
if isinstance(rhs_val, TRTTensor):
426-
rhs_val = trt_cast_int_or_float_to_bool(network, name, rhs_val)
420+
rhs_val = cast_int_or_float_to_bool(network, name, rhs_val)
427421

428422
return convert_binary_elementwise(
429423
network, target, source_ir, name, trt.ElementWiseOperation.XOR, lhs_val, rhs_val

0 commit comments

Comments
 (0)