@@ -47,6 +47,29 @@ def get_ir(target: Target) -> SourceIR:
47
47
48
48
49
49
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default)  # type: ignore[misc]
def aten_ops_native_batch_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.native_batch_norm.default`` to its TensorRT implementation.

    Forwards the op's fixed positional schema — (input, weight, bias,
    running_mean, running_var, training, momentum, eps) — to
    ``impl.normalization.native_batch_norm`` as keyword arguments.
    """
    # Name the positional slots once so the delegating call reads like the
    # aten schema.  Plain indexing (not unpacking) preserves the original
    # IndexError behavior on a malformed args tuple.
    input_val = args[0]
    weight = args[1]
    bias = args[2]
    running_mean = args[3]
    running_var = args[4]
    return impl.normalization.native_batch_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=input_val,
        weight=weight,
        bias=bias,
        running_mean=running_mean,
        running_var=running_var,
        training=args[5],
        momentum=args[6],
        eps=args[7],
    )
71
+
72
+
50
73
@dynamo_tensorrt_converter (torch .ops .aten .batch_norm ) # type: ignore[misc]
51
74
def aten_ops_batch_norm (
52
75
ctx : ConversionContext ,
@@ -68,20 +91,19 @@ def aten_ops_batch_norm(
68
91
training = args [5 ],
69
92
momentum = args [6 ],
70
93
eps = args [7 ],
71
- cudnn_enabled = args_bounds_check ( args , 8 , replacement = True ) ,
94
+ cudnn_enabled = args [ 8 ] ,
72
95
)
73
96
74
97
75
98
@dynamo_tensorrt_converter (torch .ops .aten .native_layer_norm .default ) # type: ignore[misc]
76
- @dynamo_tensorrt_converter (torch .ops .aten .layer_norm .default ) # type: ignore[misc]
77
- def aten_ops_layer_norm (
99
+ def aten_ops_native_layer_norm (
78
100
ctx : ConversionContext ,
79
101
target : Target ,
80
102
args : Tuple [Argument , ...],
81
103
kwargs : Dict [str , Argument ],
82
104
name : str ,
83
105
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
84
- return impl .normalization .layer_norm (
106
+ return impl .normalization .native_layer_norm (
85
107
ctx ,
86
108
target ,
87
109
SourceIR .ATEN ,
@@ -91,7 +113,74 @@ def aten_ops_layer_norm(
91
113
weight = args [2 ],
92
114
bias = args [3 ],
93
115
eps = args [4 ],
94
- cudnn_enable = args_bounds_check (args , 5 , replacement = True ),
116
+ )
117
+
118
+
119
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)  # type: ignore[misc]
def aten_ops_layer_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.layer_norm.default`` to its TensorRT implementation.

    Only ``input`` and ``normalized_shape`` are mandatory in the aten
    schema; the remaining slots fall back to the op's documented defaults
    via ``args_bounds_check``.
    """
    # Resolve the optional trailing arguments up front; defaults mirror the
    # aten.layer_norm schema (weight/bias default to None, eps to 1e-05,
    # cudnn_enable to True).
    weight = args_bounds_check(args, 2)
    bias = args_bounds_check(args, 3)
    eps = args_bounds_check(args, 4, 1e-05)
    cudnn_enable = args_bounds_check(args, 5, True)
    return impl.normalization.layer_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        normalized_shape=args[1],
        weight=weight,
        bias=bias,
        eps=eps,
        cudnn_enable=cudnn_enable,
    )
139
+
140
+
141
@dynamo_tensorrt_converter(torch.ops.aten.native_group_norm.default)  # type: ignore[misc]
def aten_ops_native_group_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.native_group_norm.default`` to its TensorRT implementation.

    The aten schema is fully positional: (input, weight, bias, N, C, HxW,
    group, eps).  All eight slots are forwarded by keyword to
    ``impl.normalization.native_group_norm``.
    """
    # Bind each positional slot to its schema name before delegating.
    input_val, weight, bias = args[0], args[1], args[2]
    batch_n, channels, hxw = args[3], args[4], args[5]
    group, eps = args[6], args[7]
    return impl.normalization.native_group_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=input_val,
        weight=weight,
        bias=bias,
        N=batch_n,
        C=channels,
        HxW=hxw,
        group=group,
        eps=eps,
    )
163
+
164
+
165
@dynamo_tensorrt_converter(torch.ops.aten.group_norm.default)  # type: ignore[misc]
def aten_ops_group_norm(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.group_norm.default`` to its TensorRT implementation.

    ``input`` and ``num_groups`` are required; weight, bias, eps and
    cudnn_enabled are optional and default per the aten schema (None, None,
    1e-05, True), resolved through ``args_bounds_check``.
    """
    # Fill in defaults for the optional trailing positional arguments.
    optional = {
        "weight": args_bounds_check(args, 2, None),
        "bias": args_bounds_check(args, 3, None),
        "eps": args_bounds_check(args, 4, 1e-05),
        "cudnn_enabled": args_bounds_check(args, 5, True),
    }
    return impl.normalization.group_norm(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        input=args[0],
        num_groups=args[1],
        **optional,
    )
96
185
97
186
0 commit comments