@@ -253,12 +253,6 @@ def add(
253
253
lhs_val : Union [TRTTensor , int , float ],
254
254
rhs_val : Union [TRTTensor , int , float ],
255
255
) -> TRTTensor :
256
- if isinstance (lhs_val , TRTTensor ):
257
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
258
-
259
- if isinstance (rhs_val , TRTTensor ):
260
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
261
-
262
256
return convert_binary_elementwise (
263
257
network , target , source_ir , name , trt .ElementWiseOperation .SUM , lhs_val , rhs_val
264
258
)
@@ -272,12 +266,6 @@ def mul(
272
266
lhs_val : Union [TRTTensor , int , float ],
273
267
rhs_val : Union [TRTTensor , int , float ],
274
268
) -> TRTTensor :
275
- if isinstance (lhs_val , TRTTensor ):
276
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
277
-
278
- if isinstance (rhs_val , TRTTensor ):
279
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
280
-
281
269
return convert_binary_elementwise (
282
270
network ,
283
271
target ,
@@ -297,12 +285,6 @@ def max(
297
285
lhs_val : Union [TRTTensor , int , float ],
298
286
rhs_val : Union [TRTTensor , int , float ],
299
287
) -> TRTTensor :
300
- if isinstance (lhs_val , TRTTensor ):
301
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
302
-
303
- if isinstance (rhs_val , TRTTensor ):
304
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
305
-
306
288
return convert_binary_elementwise (
307
289
network , target , source_ir , name , trt .ElementWiseOperation .MAX , lhs_val , rhs_val
308
290
)
@@ -316,12 +298,6 @@ def min(
316
298
lhs_val : Union [TRTTensor , int , float ],
317
299
rhs_val : Union [TRTTensor , int , float ],
318
300
) -> TRTTensor :
319
- if isinstance (lhs_val , TRTTensor ):
320
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
321
-
322
- if isinstance (rhs_val , TRTTensor ):
323
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
324
-
325
301
return convert_binary_elementwise (
326
302
network , target , source_ir , name , trt .ElementWiseOperation .MIN , lhs_val , rhs_val
327
303
)
@@ -335,12 +311,6 @@ def sub(
335
311
lhs_val : Union [TRTTensor , int , float ],
336
312
rhs_val : Union [TRTTensor , int , float ],
337
313
) -> TRTTensor :
338
- if isinstance (lhs_val , TRTTensor ):
339
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
340
-
341
- if isinstance (rhs_val , TRTTensor ):
342
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
343
-
344
314
return convert_binary_elementwise (
345
315
network , target , source_ir , name , trt .ElementWiseOperation .SUB , lhs_val , rhs_val
346
316
)
@@ -392,12 +362,6 @@ def floor_divide(
392
362
lhs_val : Union [TRTTensor , int , float ],
393
363
rhs_val : Union [TRTTensor , int , float ],
394
364
) -> TRTTensor :
395
- if isinstance (lhs_val , TRTTensor ):
396
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
397
-
398
- if isinstance (rhs_val , TRTTensor ):
399
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
400
-
401
365
return convert_binary_elementwise (
402
366
network ,
403
367
target ,
@@ -474,12 +438,6 @@ def eq(
474
438
lhs_val : Union [TRTTensor , int , float ],
475
439
rhs_val : Union [TRTTensor , int , float ],
476
440
) -> TRTTensor :
477
- if isinstance (lhs_val , TRTTensor ):
478
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
479
-
480
- if isinstance (rhs_val , TRTTensor ):
481
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
482
-
483
441
return convert_binary_elementwise (
484
442
network ,
485
443
target ,
@@ -499,12 +457,6 @@ def gt(
499
457
lhs_val : Union [TRTTensor , int , float ],
500
458
rhs_val : Union [TRTTensor , int , float ],
501
459
) -> TRTTensor :
502
- if isinstance (lhs_val , TRTTensor ):
503
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
504
-
505
- if isinstance (rhs_val , TRTTensor ):
506
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
507
-
508
460
return convert_binary_elementwise (
509
461
network ,
510
462
target ,
@@ -524,12 +476,6 @@ def lt(
524
476
lhs_val : Union [TRTTensor , int , float ],
525
477
rhs_val : Union [TRTTensor , int , float ],
526
478
) -> TRTTensor :
527
- if isinstance (lhs_val , TRTTensor ):
528
- lhs_val = trt_cast_int_to_float (network , name , lhs_val )
529
-
530
- if isinstance (rhs_val , TRTTensor ):
531
- rhs_val = trt_cast_int_to_float (network , name , rhs_val )
532
-
533
479
return convert_binary_elementwise (
534
480
network ,
535
481
target ,
0 commit comments