@@ -46,7 +46,23 @@ def get_ir(target: Target) -> SourceIR:
46
46
return SourceIR .UNKNOWN
47
47
48
48
49
+ def one_user_validator (node : Node ) -> bool :
50
+ # Validate only one user, which is a getitem node that accesses the first element in the list
51
+ return (
52
+ len (node .users ) == 1
53
+ and list (node .users )[0 ].target == operator .getitem
54
+ and list (node .users )[0 ].args [1 ] == 0
55
+ )
56
+
57
+
58
+ @dynamo_tensorrt_converter (torch .ops .aten .native_batch_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
59
+ @dynamo_tensorrt_converter (torch .ops .aten .batch_norm .default ) # type: ignore[misc]
49
60
@dynamo_tensorrt_converter (torch .ops .aten .batch_norm ) # type: ignore[misc]
61
+ @enforce_tensor_types (
62
+ {
63
+ 0 : (TRTTensor ,),
64
+ }
65
+ ) # type: ignore[misc]
50
66
def aten_ops_batch_norm (
51
67
ctx : ConversionContext ,
52
68
target : Target ,
@@ -59,14 +75,103 @@ def aten_ops_batch_norm(
59
75
target ,
60
76
SourceIR .ATEN ,
61
77
name ,
62
- args [0 ],
63
- args [1 ],
64
- args [2 ],
65
- args [3 ],
66
- args [4 ],
67
- args [5 ],
68
- args [6 ],
69
- args [7 ],
78
+ input = args [0 ],
79
+ weight = args [1 ],
80
+ bias = args [2 ],
81
+ running_mean = args [3 ],
82
+ running_var = args [4 ],
83
+ training = args [5 ],
84
+ momentum = args [6 ],
85
+ eps = args [7 ],
86
+ cudnn_enabled = args_bounds_check (args , 8 , True ),
87
+ return_mean_rstd = (target == torch .ops .aten .native_batch_norm .default ),
88
+ )
89
+
90
+
91
+ @dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
92
+ @dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
93
+ @dynamo_tensorrt_converter (torch .ops .aten .layer_norm ) # type: ignore[misc]
94
+ @enforce_tensor_types (
95
+ {
96
+ 0 : (TRTTensor ,),
97
+ }
98
+ ) # type: ignore[misc]
99
+ def aten_ops_layer_norm (
100
+ ctx : ConversionContext ,
101
+ target : Target ,
102
+ args : Tuple [Argument , ...],
103
+ kwargs : Dict [str , Argument ],
104
+ name : str ,
105
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
106
+ return impl .normalization .layer_norm (
107
+ ctx ,
108
+ target ,
109
+ SourceIR .ATEN ,
110
+ name ,
111
+ input = args [0 ],
112
+ normalized_shape = args [1 ],
113
+ weight = args_bounds_check (args , 2 ),
114
+ bias = args_bounds_check (args , 3 ),
115
+ eps = args_bounds_check (args , 4 , 1e-05 ),
116
+ cudnn_enable = args_bounds_check (args , 5 , True ),
117
+ return_mean_rstd = (target == torch .ops .aten .native_layer_norm .default ),
118
+ )
119
+
120
+
121
+ @dynamo_tensorrt_converter (torch .ops .aten .native_group_norm .default , capability_validator = one_user_validator ) # type: ignore[misc]
122
+ @enforce_tensor_types (
123
+ {
124
+ 0 : (TRTTensor ,),
125
+ }
126
+ ) # type: ignore[misc]
127
+ def aten_ops_native_group_norm (
128
+ ctx : ConversionContext ,
129
+ target : Target ,
130
+ args : Tuple [Argument , ...],
131
+ kwargs : Dict [str , Argument ],
132
+ name : str ,
133
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
134
+ return impl .normalization .native_group_norm (
135
+ ctx ,
136
+ target ,
137
+ SourceIR .ATEN ,
138
+ name ,
139
+ input = args [0 ],
140
+ weight = args [1 ],
141
+ bias = args [2 ],
142
+ N = args [3 ],
143
+ C = args [4 ],
144
+ HxW = args [5 ],
145
+ group = args [6 ],
146
+ eps = args [7 ],
147
+ )
148
+
149
+
150
+ @dynamo_tensorrt_converter (torch .ops .aten .group_norm .default ) # type: ignore[misc]
151
+ @dynamo_tensorrt_converter (torch .ops .aten .group_norm ) # type: ignore[misc]
152
+ @enforce_tensor_types (
153
+ {
154
+ 0 : (TRTTensor ,),
155
+ }
156
+ ) # type: ignore[misc]
157
+ def aten_ops_group_norm (
158
+ ctx : ConversionContext ,
159
+ target : Target ,
160
+ args : Tuple [Argument , ...],
161
+ kwargs : Dict [str , Argument ],
162
+ name : str ,
163
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
164
+ return impl .normalization .group_norm (
165
+ ctx ,
166
+ target ,
167
+ SourceIR .ATEN ,
168
+ name ,
169
+ input = args [0 ],
170
+ num_groups = args [1 ],
171
+ weight = args_bounds_check (args , 2 , None ),
172
+ bias = args_bounds_check (args , 3 , None ),
173
+ eps = args_bounds_check (args , 4 , 1e-05 ),
174
+ cudnn_enabled = args_bounds_check (args , 5 , True ),
70
175
)
71
176
72
177
@@ -328,27 +433,6 @@ def aten_ops_matmul(
328
433
)
329
434
330
435
331
- @dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
332
- def aten_ops_layernorm (
333
- ctx : ConversionContext ,
334
- target : Target ,
335
- args : Tuple [Argument , ...],
336
- kwargs : Dict [str , Argument ],
337
- name : str ,
338
- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
339
- return impl .normalization .layer_norm (
340
- ctx ,
341
- target ,
342
- SourceIR .ATEN ,
343
- name ,
344
- args [0 ],
345
- args [1 ],
346
- args [2 ],
347
- args [3 ],
348
- args [4 ],
349
- )
350
-
351
-
352
436
@dynamo_tensorrt_converter (torch .ops .aten .rsqrt .default ) # type: ignore[misc]
353
437
def aten_ops_rsqrt (
354
438
ctx : ConversionContext ,
@@ -763,15 +847,6 @@ def aten_ops_prod(
763
847
)
764
848
765
849
766
- def one_user_validator (node : Node ) -> bool :
767
- # Validate only one user, which is a getitem node that accesses the first element in the list
768
- return (
769
- len (node .users ) == 1
770
- and list (node .users )[0 ].target == operator .getitem
771
- and list (node .users )[0 ].args [1 ] == 0
772
- )
773
-
774
-
775
850
@dynamo_tensorrt_converter (torch .ops .aten .max .default ) # type: ignore[misc]
776
851
@dynamo_tensorrt_converter (torch .ops .aten .max .dim , capability_validator = one_user_validator ) # type: ignore[misc]
777
852
def aten_ops_max (
0 commit comments