5
5
from torch .fx .node import Target
6
6
from torch_tensorrt .dynamo ._SourceIR import SourceIR
7
7
from torch_tensorrt .dynamo .conversion .converter_utils import (
8
- trt_cast_int_or_float_to_bool ,
9
- trt_cast_int_to_float ,
8
+ cast_int_int_div_trt_tensor ,
9
+ cast_int_or_float_to_bool ,
10
10
)
11
11
from torch_tensorrt .dynamo .conversion .impl .elementwise .base import (
12
12
convert_binary_elementwise ,
@@ -324,11 +324,8 @@ def div(
324
324
lhs_val : Union [TRTTensor , int , float ],
325
325
rhs_val : Union [TRTTensor , int , float ],
326
326
) -> TRTTensor :
327
- if isinstance (lhs_val , TRTTensor ):
328
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
329
-
330
- if isinstance (rhs_val , TRTTensor ):
331
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
327
+ if isinstance (lhs_val , TRTTensor ) and isinstance (rhs_val , TRTTensor ):
328
+ lhs_val , rhs_val = cast_int_int_div_trt_tensor (network , lhs_val , rhs_val , name )
332
329
333
330
return convert_binary_elementwise (
334
331
network , target , source_ir , name , trt .ElementWiseOperation .DIV , lhs_val , rhs_val
@@ -343,11 +340,8 @@ def pow(
343
340
lhs_val : Union [TRTTensor , int , float ],
344
341
rhs_val : Union [TRTTensor , int , float ],
345
342
) -> TRTTensor :
346
- if isinstance (lhs_val , TRTTensor ):
347
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
348
-
349
- if isinstance (rhs_val , TRTTensor ):
350
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
343
+ if isinstance (lhs_val , TRTTensor ) and isinstance (rhs_val , TRTTensor ):
344
+ lhs_val , rhs_val = cast_int_int_div_trt_tensor (network , lhs_val , rhs_val , name )
351
345
352
346
return convert_binary_elementwise (
353
347
network , target , source_ir , name , trt .ElementWiseOperation .POW , lhs_val , rhs_val
@@ -382,10 +376,10 @@ def logical_and(
382
376
rhs_val : Union [TRTTensor , int , float ],
383
377
) -> TRTTensor :
384
378
if isinstance (lhs_val , TRTTensor ):
385
- lhs_val = trt_cast_int_or_float_to_bool (network , name , lhs_val )
379
+ lhs_val = cast_int_or_float_to_bool (network , name , lhs_val )
386
380
387
381
if isinstance (rhs_val , TRTTensor ):
388
- rhs_val = trt_cast_int_or_float_to_bool (network , name , rhs_val )
382
+ rhs_val = cast_int_or_float_to_bool (network , name , rhs_val )
389
383
390
384
return convert_binary_elementwise (
391
385
network , target , source_ir , name , trt .ElementWiseOperation .AND , lhs_val , rhs_val
@@ -401,10 +395,10 @@ def logical_or(
401
395
rhs_val : Union [TRTTensor , int , float ],
402
396
) -> TRTTensor :
403
397
if isinstance (lhs_val , TRTTensor ):
404
- lhs_val = trt_cast_int_or_float_to_bool (network , name , lhs_val )
398
+ lhs_val = cast_int_or_float_to_bool (network , name , lhs_val )
405
399
406
400
if isinstance (rhs_val , TRTTensor ):
407
- rhs_val = trt_cast_int_or_float_to_bool (network , name , rhs_val )
401
+ rhs_val = cast_int_or_float_to_bool (network , name , rhs_val )
408
402
409
403
return convert_binary_elementwise (
410
404
network , target , source_ir , name , trt .ElementWiseOperation .OR , lhs_val , rhs_val
@@ -420,10 +414,10 @@ def logical_xor(
420
414
rhs_val : Union [TRTTensor , int , float ],
421
415
) -> TRTTensor :
422
416
if isinstance (lhs_val , TRTTensor ):
423
- lhs_val = trt_cast_int_or_float_to_bool (network , name , lhs_val )
417
+ lhs_val = cast_int_or_float_to_bool (network , name , lhs_val )
424
418
425
419
if isinstance (rhs_val , TRTTensor ):
426
- rhs_val = trt_cast_int_or_float_to_bool (network , name , rhs_val )
420
+ rhs_val = cast_int_or_float_to_bool (network , name , rhs_val )
427
421
428
422
return convert_binary_elementwise (
429
423
network , target , source_ir , name , trt .ElementWiseOperation .XOR , lhs_val , rhs_val
0 commit comments