import torch
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
-    convert_binary_elementwise,
-)
-from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op
from torch_tensorrt.fx.converters.converter_utils import (
    get_positive_dim,
    get_trt_plugin,
+    get_trt_tensor,
    has_dynamic_shape,
    set_layer_name,
    to_numpy,
@@ -188,79 +187,77 @@ def layer_norm_no_plugin(

    shape = weight.shape
    broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
-    gamma = to_numpy(weight.reshape(*shape))
-    beta = to_numpy(bias.reshape(*shape))
+    gamma = to_numpy(weight).reshape(shape)
+    beta = to_numpy(bias).reshape(shape)

-    axes = 0
-    for d in range(len(shape)):
-        axes |= 1 << (len(input.shape) - d - 1)
+    dims = list(range(len(input.shape) - len(shape), len(input.shape)))
+    axes = get_axes_for_reduce_op(dims)

    # E[x]
    mean_expected_layer = network.add_reduce(
        input, trt.ReduceOperation.AVG, axes, keep_dims=True
    )
    set_layer_name(mean_expected_layer, target, f"{name}_mean_expected", source_ir)

-    # X- E[x]
-    sub_trt = convert_binary_elementwise(
+    # X - E[x]
+    sub_trt = impl.elementwise.sub(
        network,
        target,
        source_ir,
        f"{name}_sub",
-        trt.ElementWiseOperation.SUB,
        input,
        mean_expected_layer.get_output(0),
    )
-    # Variance = mean(pow(x_sub_mean,2))
+
+    # variance = mean(pow(x_sub_mean, 2))
    pow_tensor = network.add_constant(
        (1,) * len(input.shape),
        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
    )
    pow_tensor.name = f"{name}_power"
-    pow_var = convert_binary_elementwise(
+    pow_var = impl.elementwise.pow(
        network,
        target,
        source_ir,
        f"{name}_pow_var",
-        trt.ElementWiseOperation.POW,
        sub_trt,
        pow_tensor.get_output(0),
    )
    mean_trt_layer = network.add_reduce(
        pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True
    )
    set_layer_name(mean_trt_layer, target, f"{name}_mean", source_ir)
-    # Variance + eps
+
+    # var + eps
    eps_tensor = network.add_constant(
        (1,) * len(input.shape),
        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
    )
    eps_tensor.name = f"{name}_eps"
-    add_trt = convert_binary_elementwise(
+
+    # sqrt((var + eps))
+    add_trt = impl.elementwise.add(
        network,
        target,
        source_ir,
        f"{name}_add",
-        trt.ElementWiseOperation.SUM,
        mean_trt_layer.get_output(0),
        eps_tensor.get_output(0),
    )
-    # SQRT((Var + eps))
-    sqrt_trt = convert_unary(
+    sqrt_trt = impl.unary.sqrt(
        network,
        target,
        source_ir,
        f"{name}_sqrt",
-        trt.UnaryOperation.SQRT,
        add_trt,
    )
-    # (x - E[x]) / sqrt((var + eps))
-    div_trt = convert_binary_elementwise(
+
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
        network,
        target,
        source_ir,
        f"{name}_div_trt",
-        trt.ElementWiseOperation.DIV,
        sub_trt,
        sqrt_trt,
    )
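Both formulations of the reduce axes select the same trailing dimensions: the removed loop sets one bit per normalized axis counted from the end, while the new code names those axes explicitly and converts them to TensorRT's reduce bitmask. A minimal sketch of the equivalence, assuming `get_axes_for_reduce_op` ORs together `1 << dim` (the ranks below are made up for illustration):

```python
# Sketch only: check the old bitmask loop against the new dims-based form.
input_rank = 4   # e.g. an (N, C, H, W) tensor
weight_rank = 2  # layer norm over the last two dims

# Removed code: one bit per normalized dim, counted from the last axis.
old_axes = 0
for d in range(weight_rank):
    old_axes |= 1 << (input_rank - d - 1)

# New code, assuming get_axes_for_reduce_op ORs together 1 << dim.
dims = list(range(input_rank - weight_rank, input_rank))
new_axes = 0
for dim in dims:
    new_axes |= 1 << dim

assert old_axes == new_axes == 0b1100  # bits for dims 2 and 3
```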
@@ -270,32 +267,113 @@ def layer_norm_no_plugin(
        gamma.shape, trt.Weights(np.ascontiguousarray(gamma))
    )
    gamma_tensor.name = f"{name}_gamma"
+
    assert beta is not None
    beta_tensor = network.add_constant(
        gamma.shape, trt.Weights(np.ascontiguousarray(beta))
    )
    beta_tensor.name = f"{name}_beta"
+
    # y * gamma + beta
-    scale_layer = convert_binary_elementwise(
+    scaled_y = impl.elementwise.mul(
        network,
        target,
        source_ir,
        f"{name}_scale",
-        trt.ElementWiseOperation.PROD,
        div_trt,
        gamma_tensor.get_output(0),
    )
-    return convert_binary_elementwise(
+    return impl.elementwise.add(
        network,
        target,
        source_ir,
        name,
-        trt.ElementWiseOperation.SUM,
-        scale_layer,
+        scaled_y,
        beta_tensor.get_output(0),
    )
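Read end to end, the two hunks above assemble the textbook layer norm, y = (x - E[x]) / sqrt(Var[x] + eps) * gamma + beta, out of TensorRT reduce and elementwise layers. A minimal NumPy sketch of the same computation, for reference only (the function name is hypothetical):

```python
import numpy as np

def layer_norm_reference(x, gamma, beta, eps=1e-5):
    # Normalize over the trailing dims covered by gamma, as the converter does.
    dims = tuple(range(x.ndim - gamma.ndim, x.ndim))
    mean = x.mean(axis=dims, keepdims=True)                 # E[x]
    var = ((x - mean) ** 2).mean(axis=dims, keepdims=True)  # Var[x]
    y = (x - mean) / np.sqrt(var + eps)                     # normalized
    return y * gamma + beta                                 # affine rescale
```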


+def native_group_norm(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    N: int,
+    C: int,
+    HxW: int,
+    group: int,
+    eps: float,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return group_norm(
+        network,
+        target,
+        source_ir,
+        name,
+        input,
+        group,
+        weight,
+        bias,
+        eps,
+        cudnn_enabled=True,
+    )
+
+
+def group_norm(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    num_groups: int,
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    eps: float,
+    cudnn_enabled: bool,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    if not isinstance(input, trt.tensorrt.ITensor):
+        raise RuntimeError(
+            f"GroupNorm received input {input} that is not part "
+            "of the TensorRT region!"
+        )
+
+    if weight is None:
+        weight = to_numpy(1.0)
+
+    if bias is None:
+        bias = to_numpy(0.0)
+
+    scale = get_trt_tensor(network, weight, "scale")
+    bias = get_trt_tensor(network, bias, "bias")
+
+    eps_field = trt.PluginField(
+        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+    )
+    num_groups_field = trt.PluginField(
+        "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
+    )
+
+    field_collection = trt.PluginFieldCollection([eps_field, num_groups_field])
+
+    try:
+        # Here's the schema of the plugin:
+        # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
+        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
+    except AssertionError:
+        _LOGGER.error(
+            "Unable to find GroupNormalizationPlugin; no TensorRT fallback is implemented."
+        )
+        raise
+
+    layer = network.add_plugin_v2([input, scale, bias], plugin)
+    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
+
+    # PyTorch requires three return values: (out, mean, rstd)
+    dummy_tensor = torch.tensor(0)
+    return layer.get_output(0), dummy_tensor, dummy_tensor
+
+
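For context, the plugin normalizes each group of channels with its own mean and variance. A NumPy sketch of the semantics `group_norm` delegates to the plugin, assuming an (N, C, ...) layout (names are illustrative):

```python
import numpy as np

def group_norm_reference(x, gamma, beta, num_groups, eps=1e-5):
    n, c = x.shape[0], x.shape[1]
    # Collapse each group of channels (and all spatial dims) into one axis.
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(axis=2, keepdims=True)
    var = g.var(axis=2, keepdims=True)
    y = ((g - mean) / np.sqrt(var + eps)).reshape(x.shape)
    # Per-channel affine parameters broadcast over the spatial dims.
    shape = (1, c) + (1,) * (x.ndim - 2)
    return y * gamma.reshape(shape) + beta.reshape(shape)
```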
def softmax(
    network: TRTNetwork,
    target: Target,