@@ -21,9 +21,11 @@
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt

from .converter_utils import *  # noqa: F403
+from .activation import *
+from .operator import *
+
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
-import activation
-import operator
+

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -40,7 +42,7 @@ def aten_ops_add(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_add(network, target, None, kwargs_new, name)
+    return add_add(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.mean.dim)
@@ -143,13 +145,13 @@ def aten_ops_div(
    }
    rounding_mode = kwargs.get("rounding_mode")
    if rounding_mode is None:
-        return operator.add_div(network, target, None, kwargs_new, name)
+        return add_div(network, target, None, kwargs_new, name)
    elif rounding_mode == "floor":
-        return operator.add_floor_div(
+        return add_floor_div(
            network, target, None, kwargs_new, name
        )
    elif rounding_mode == "trunc":
-        return operator.add_trunc_div(
+        return add_trunc_div(
            network, target, None, kwargs_new, name
        )
    else:
@@ -170,7 +172,7 @@ def aten_ops_floor_div(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.fmod.Scalar)
@@ -186,7 +188,7 @@ def aten_ops_fmod(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_fmod(network, target, None, kwargs_new, name)
+    return add_fmod(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.linear)
@@ -203,7 +205,7 @@ def aten_ops_linear(
        "bias": args[2],
    }

-    return operator.add_linear(network, target, None, kwargs_new, name)
+    return add_linear(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.max_pool3d)
@@ -252,10 +254,11 @@ def aten_ops_mul(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, None, kwargs_new, name)


-@tensorrt_converter(torch.ops.aten.matmul.Tensor)
+@tensorrt_converter(torch.ops.aten.matmul)
+@tensorrt_converter(torch.ops.aten.mm.default)
def aten_ops_matmul(
    network: TRTNetwork,
    target: Target,
@@ -267,7 +270,7 @@ def aten_ops_matmul(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_matmul(network, target, None, kwargs_new, name)
+    return add_matmul(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar)
@@ -283,7 +286,7 @@ def aten_ops_pow(
        "input": args[0],
        "exponent": args[1],
    }
-    return operator.add_pow(network, target, None, kwargs_new, name)
+    return add_pow(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.relu.default)
@@ -297,7 +300,7 @@ def aten_ops_relu(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_relu(network, target, kwargs_new, name)
+    return add_relu(network, target, kwargs_new, name)

@tensorrt_converter(torch.ops.aten.sub.Tensor)
def aten_ops_sub(
@@ -311,7 +314,7 @@ def aten_ops_sub(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.view.default)
@@ -378,7 +381,7 @@ def aten_ops_expand(
        "input": args[0],
        "sizes": args[1],
    }
-    return operator.add_expand(network, target, kwargs_new, name)
+    return add_expand(network, target, kwargs_new, name)


@tensorrt_converter(operator.floordiv)
@@ -393,7 +396,7 @@ def aten_ops_operator_floordiv(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_floor_div(network, target, None, kwargs_new, name)
+    return add_floor_div(network, target, None, kwargs_new, name)


@tensorrt_converter(operator.mul)
@@ -408,7 +411,7 @@ def aten_ops_operator_mul(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_mul(network, target, None, kwargs_new, name)
+    return add_mul(network, target, None, kwargs_new, name)


@tensorrt_converter(operator.add)
@@ -423,7 +426,7 @@ def aten_ops_operator_add(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_add(network, target, None, kwargs_new, name)
+    return add_add(network, target, None, kwargs_new, name)


@tensorrt_converter(operator.sub)
@@ -438,7 +441,7 @@ def aten_ops_operator_sub(
        "input": args[0],
        "other": args[1],
    }
-    return operator.add_sub(network, target, None, kwargs_new, name)
+    return add_sub(network, target, None, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.sym_numel)
@@ -497,9 +500,10 @@ def aten_ops_slice(
        "stop": args[3],
        "step": args[4],
    }
-    return operator.add_slice(network, target, kwargs_new, name)
+    return add_slice(network, target, kwargs_new, name)

-@tensorrt_converter(torch.ops.aten.select.Tensor)
+
+@tensorrt_converter(torch.ops.aten.select)
def aten_ops_select(
    network: TRTNetwork,
    target: Target,
@@ -512,7 +516,7 @@ def aten_ops_select(
        "dim": args[1],
        "index": args[2],
    }
-    return operator.add_select(network, target, kwargs_new, name)
+    return add_select(network, target, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.leaky_relu.default)
@@ -526,7 +530,7 @@ def aten_ops_leaky_relu(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_leaky_relu(network, target, kwargs_new, name)
+    return add_leaky_relu(network, target, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.elu.default)
@@ -540,7 +544,7 @@ def aten_ops_elu(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_elu(network, target, kwargs_new, name)
+    return add_elu(network, target, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.selu.default)
@@ -554,7 +558,7 @@ def aten_ops_selu(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.selu(network, target, kwargs_new, name)
+    return add_selu(network, target, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.gelu.default)
@@ -568,22 +572,7 @@ def aten_ops_gelu(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_gelu(network, target, kwargs_new, name)
-
-
-@tensorrt_converter(torch.ops.aten.softsign.default)
-def aten_ops_softsign(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return activation.add_softsign(network, target, kwargs_new, name)
-
+    return add_gelu(network, target, kwargs_new, name)

@tensorrt_converter(torch.ops.aten.tanh.default)
def aten_ops_tanh(
@@ -596,34 +585,7 @@ def aten_ops_tanh(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_tanh(network, target, kwargs_new, name)
-
-@tensorrt_converter(torch.ops.aten.softsign.default)
-def aten_ops_softsign(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return activation.add_softsign(network, target, kwargs_new, name)
-
-
-@tensorrt_converter(torch.ops.aten.softsign.default)
-def aten_ops_hard_sigmoid(
-    network: TRTNetwork,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    kwargs_new = {
-        "input": args[0],
-    }
-    return activation.add_hard_sigmoid(network, target, kwargs_new, name)
+    return add_tanh(network, target, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.hardtanh.default)
@@ -637,7 +599,7 @@ def aten_ops_hard_tanh(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_hard_tanh(network, target, kwargs_new, name)
+    return add_hard_tanh(network, target, kwargs_new, name)


@tensorrt_converter(torch.ops.aten.sigmoid.default)
@@ -651,7 +613,7 @@ def aten_ops_sigmoid(
    kwargs_new = {
        "input": args[0],
    }
-    return activation.add_sigmoid(network, target, kwargs_new, name)
+    return add_sigmoid(network, target, kwargs_new, name)

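Note: every hunk above repeats one pattern: a `@tensorrt_converter` decorator registers a converter for an aten (or Python `operator`) target, and the converter repacks the FX node's positional `args` into the keyword dict that a shared `add_*` helper expects. Below is a minimal, self-contained sketch of that registry pattern, assuming stand-in names throughout; it is not torch_tensorrt's actual implementation, and the stub `add_add` merely imitates the real helper imported from `.operator`.

    # Minimal sketch (assumed names, not the real torch_tensorrt API) of the
    # decorator/dispatch pattern used throughout this diff.
    from typing import Any, Callable, Dict, Tuple

    CONVERTERS: Dict[Any, Callable[..., Any]] = {}


    def tensorrt_converter(target: Any) -> Callable:
        """Map an op target (e.g. torch.ops.aten.add.Tensor) to its converter."""

        def register(fn: Callable) -> Callable:
            CONVERTERS[target] = fn
            return fn

        return register


    def add_add(network: Any, target: Any, source_ir: Any, kwargs: Dict, name: str) -> str:
        # Stand-in for the shared elementwise helper imported from .operator.
        return f"{name}: {kwargs['input']} + {kwargs['other']}"


    @tensorrt_converter("aten.add.Tensor")  # the real key is a torch op overload
    def aten_ops_add(network: Any, target: Any, args: Tuple, kwargs: Dict, name: str) -> str:
        # Repack positional FX-node args into the kwargs the helper expects.
        kwargs_new = {"input": args[0], "other": args[1]}
        return add_add(network, target, None, kwargs_new, name)


    # Dispatch by target, as a converter registry would during lowering:
    print(CONVERTERS["aten.add.Tensor"](None, "aten.add.Tensor", (2, 3), {}, "add_1"))

Stacking several decorators on one function, as the `aten.matmul` / `aten.mm.default` hunk does, simply inserts the same converter under multiple registry keys.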