@@ -1,5 +1,5 @@
 import logging
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from typing import List, Optional, Sequence, Tuple, Union

 import numpy as np
 import tensorrt as trt
@@ -16,7 +16,6 @@
     get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -204,240 +203,80 @@ def layer_norm(
     return layer_norm.get_output(0)


-def native_group_norm(
+def group_norm(
     ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     N: int,
     C: int,
     HxW: int,
     group: int,
     eps: float,
-    return_mean_rstd: bool = True,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
-    # with INormalization Layer
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
    assert (
         len(input.shape) >= 3
-    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    ), f"Expected at least 3 dimensions for input tensor but got {len(input.shape)}"

-    B = input.shape[0]
-    # if C is provided, it must be as same as the channel from the input shape,
-    # else if C is zero, we should get the channel from the input shape
-    if C == 0:
-        C = input.shape[1]
     assert (
         C == input.shape[1]
-    ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-    # Groups are a subdivision of the channel dimension.
-    assert (
-        C % group == 0
-    ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-    input = get_trt_tensor(ctx, input, f"{name}_input")
-
-    shape = list(input.shape)
-
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B * group
-        elif i == 1:
-            shape[i] = C // group
-        elif i > 1 and s == -1:
-            shape[i] = 0
-
-    # Normalize every group.
-    reshaped_input = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_input",
-        input,
-        shape,
-    )
-
-    if weight is None:
-        weight = to_numpy(1.0)
+    ), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"

-    if bias is None:
-        bias = to_numpy(0.0)
+    weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
+    bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)

-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
-
-    dims = list(range(1, len(input.shape)))
-
-    # E[X]
-    mean_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean",
-        reshaped_input,
-        dims,
-        True,
-    )
-
-    mean_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_mean_trt",
-        mean_trt,
-        reshaped_input.shape,
-    )
-
-    # X - E[X]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        reshaped_input,
-        mean_trt,
-    )
-
-    # variance
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow",
-        sub_trt,
-        pow_trt,
-    )
-
-    var_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean_var",
-        pow_var,
-        dims,
-        True,
-    )
-
-    var_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_var_trt",
-        var_trt,
-        reshaped_input.shape,
-    )
-
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        var_trt,
-        eps_trt,
-    )
+    shape = [1, group] + [1] * (len(input.shape) - 2)

-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
+    expanded_weight_one = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
     )
-
-    # y = (X - E[X]) / sqrt((var + eps))
-    output = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
+    expanded_bias_zero = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
     )

-    shape = list(output.shape)
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B
-        elif i == 1:
-            shape[i] = C
-        elif i > 1 and s == -1:
-            shape[i] = 0
+    axes = get_axes_for_reduce_op([i for i in range(2, len(input.shape))])

-    reshaped_output = impl.shuffle.reshape(
-        ctx, target, source_ir, f"{name}_reshape_output", output, shape
-    )
-    reshaped_gamma = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_gamma",
-        weight,
-        weight_bias_shape,
+    # INormalizationLayer scales the normalized output per-group, whereas PyTorch scales it per-channel,
+    # which would produce different results. Have TensorRT apply a no-op scale here and do the per-channel scaling ourselves below.
+    layer = ctx.net.add_normalization(
+        input, expanded_weight_one, expanded_bias_zero, axes
     )
+    layer.epsilon = eps
+    layer.num_groups = group
+    set_layer_name(layer, target, name, source_ir)
+    output = layer.get_output(0)

-    reshaped_output = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        reshaped_output,
-        reshaped_gamma,
-    )
+    shape[1] = C

-    reshaped_bias = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_beta",
-        bias,
-        weight_bias_shape,
-    )
-    reshaped_output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        reshaped_output,
-        reshaped_bias,
-    )
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return reshaped_output, None, None
-    return reshaped_output
+    if weight is not None:
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+        weight = cast_trt_tensor(
+            ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir
+        )
+        weight = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_weight", weight, shape
+        )
+        output = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_mul_weight", output, weight
+        )

+    if bias is not None:
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+        bias = cast_trt_tensor(
+            ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir
+        )
+        bias = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_bias", bias, shape
+        )
+        output = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add_bias", output, bias
+        )

-def group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    num_groups: int,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-    cudnn_enabled: bool,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return native_group_norm(
-        ctx,
-        target,
-        source_ir,
-        name,
-        input,
-        weight,
-        bias,
-        0,
-        0,
-        0,
-        num_groups,
-        eps,
-        return_mean_rstd=False,
-    )
+    # return fake mean and rstd for now
+    return output, None, None


 def softmax(
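For reference, the computation the new converter builds is equivalent to the eager-mode PyTorch sketch below: the statistics are taken per group with a unit scale and zero bias (which is what the INormalization layer receives above), and the per-channel weight and bias are applied afterwards as a separate multiply and add. This is illustrative only and not part of the change; the function name, test shapes, and tolerance are assumptions.

# Reference sketch (assumed example, not converter code): per-group normalization
# followed by a separate per-channel affine, mirroring the converter's structure.
import torch

def group_norm_reference(x, weight, bias, num_groups, eps=1e-5):
    N, C = x.shape[0], x.shape[1]
    # Normalize over each group of C // num_groups channels plus all spatial dims,
    # as the normalization layer does when given unit scale and zero bias.
    grouped = x.reshape(N, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)
    var = grouped.var(dim=-1, unbiased=False, keepdim=True)
    normalized = ((grouped - mean) / torch.sqrt(var + eps)).reshape_as(x)
    # Per-channel affine, matching the explicit mul/add the converter appends.
    affine_shape = (1, C) + (1,) * (x.dim() - 2)
    return normalized * weight.reshape(affine_shape) + bias.reshape(affine_shape)

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
expected = torch.nn.functional.group_norm(x, 3, w, b, eps=1e-5)
assert torch.allclose(group_norm_reference(x, w, b, 3), expected, atol=1e-5)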