Skip to content

Commit 25c1ff4

Browse files
committed
delete int2float conversion for some ops
1 parent e7ed7de commit 25c1ff4

File tree

1 file changed

+0
-54
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/elementwise

1 file changed

+0
-54
lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,6 @@ def add(
253253
lhs_val: Union[TRTTensor, int, float],
254254
rhs_val: Union[TRTTensor, int, float],
255255
) -> TRTTensor:
256-
if isinstance(lhs_val, TRTTensor):
257-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
258-
259-
if isinstance(rhs_val, TRTTensor):
260-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
261-
262256
return convert_binary_elementwise(
263257
network, target, source_ir, name, trt.ElementWiseOperation.SUM, lhs_val, rhs_val
264258
)
@@ -272,12 +266,6 @@ def mul(
272266
lhs_val: Union[TRTTensor, int, float],
273267
rhs_val: Union[TRTTensor, int, float],
274268
) -> TRTTensor:
275-
if isinstance(lhs_val, TRTTensor):
276-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
277-
278-
if isinstance(rhs_val, TRTTensor):
279-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
280-
281269
return convert_binary_elementwise(
282270
network,
283271
target,
@@ -297,12 +285,6 @@ def max(
297285
lhs_val: Union[TRTTensor, int, float],
298286
rhs_val: Union[TRTTensor, int, float],
299287
) -> TRTTensor:
300-
if isinstance(lhs_val, TRTTensor):
301-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
302-
303-
if isinstance(rhs_val, TRTTensor):
304-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
305-
306288
return convert_binary_elementwise(
307289
network, target, source_ir, name, trt.ElementWiseOperation.MAX, lhs_val, rhs_val
308290
)
@@ -316,12 +298,6 @@ def min(
316298
lhs_val: Union[TRTTensor, int, float],
317299
rhs_val: Union[TRTTensor, int, float],
318300
) -> TRTTensor:
319-
if isinstance(lhs_val, TRTTensor):
320-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
321-
322-
if isinstance(rhs_val, TRTTensor):
323-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
324-
325301
return convert_binary_elementwise(
326302
network, target, source_ir, name, trt.ElementWiseOperation.MIN, lhs_val, rhs_val
327303
)
@@ -335,12 +311,6 @@ def sub(
335311
lhs_val: Union[TRTTensor, int, float],
336312
rhs_val: Union[TRTTensor, int, float],
337313
) -> TRTTensor:
338-
if isinstance(lhs_val, TRTTensor):
339-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
340-
341-
if isinstance(rhs_val, TRTTensor):
342-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
343-
344314
return convert_binary_elementwise(
345315
network, target, source_ir, name, trt.ElementWiseOperation.SUB, lhs_val, rhs_val
346316
)
@@ -392,12 +362,6 @@ def floor_divide(
392362
lhs_val: Union[TRTTensor, int, float],
393363
rhs_val: Union[TRTTensor, int, float],
394364
) -> TRTTensor:
395-
if isinstance(lhs_val, TRTTensor):
396-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
397-
398-
if isinstance(rhs_val, TRTTensor):
399-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
400-
401365
return convert_binary_elementwise(
402366
network,
403367
target,
@@ -474,12 +438,6 @@ def eq(
474438
lhs_val: Union[TRTTensor, int, float],
475439
rhs_val: Union[TRTTensor, int, float],
476440
) -> TRTTensor:
477-
if isinstance(lhs_val, TRTTensor):
478-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
479-
480-
if isinstance(rhs_val, TRTTensor):
481-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
482-
483441
return convert_binary_elementwise(
484442
network,
485443
target,
@@ -499,12 +457,6 @@ def gt(
499457
lhs_val: Union[TRTTensor, int, float],
500458
rhs_val: Union[TRTTensor, int, float],
501459
) -> TRTTensor:
502-
if isinstance(lhs_val, TRTTensor):
503-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
504-
505-
if isinstance(rhs_val, TRTTensor):
506-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
507-
508460
return convert_binary_elementwise(
509461
network,
510462
target,
@@ -524,12 +476,6 @@ def lt(
524476
lhs_val: Union[TRTTensor, int, float],
525477
rhs_val: Union[TRTTensor, int, float],
526478
) -> TRTTensor:
527-
if isinstance(lhs_val, TRTTensor):
528-
lhs_val = trt_cast_int_to_float(network, name, lhs_val)
529-
530-
if isinstance(rhs_val, TRTTensor):
531-
rhs_val = trt_cast_int_to_float(network, name, rhs_val)
532-
533479
return convert_binary_elementwise(
534480
network,
535481
target,

0 commit comments

Comments
 (0)