@@ -1,5 +1,5 @@
import logging
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
@@ -16,7 +16,6 @@
    get_trt_tensor,
    has_dynamic_shape,
    set_layer_name,
-    to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
@@ -203,234 +202,72 @@ def native_group_norm(
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
    N: int,
    C: int,
    HxW: int,
    group: int,
    eps: float,
-    return_mean_rstd: bool = True,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    # TODO: Ask TRT team about the usage of INormalization Layer usage with num_groups and update the implementation
-    # with INormalization Layer
-    assert (
-        len(input.shape) >= 3
-    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
-
-    B = input.shape[0]
-    # if C is provided, it must be as same as the channel from the input shape,
-    # else if C is zero, we should get the channel from the input shape
-    if C == 0:
-        C = input.shape[1]
-    assert (
-        C == input.shape[1]
-    ), f"The number of Channel={C} must be equal to the number of channels in the input shape={input.shape[1]}"
-    # Groups are a subdivision of the channel dimension.
-    assert (
-        C % group == 0
-    ), f"The num of channels ({C}) should be divisible by num_groups ({group})!"
-    input = get_trt_tensor(ctx, input, f"{name}_input")
-
-    shape = list(input.shape)
-
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B * group
-        elif i == 1:
-            shape[i] = C // group
-        elif i > 1 and s == -1:
-            shape[i] = 0
-
-    # Normalize every group.
-    reshaped_input = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_input",
-        input,
-        shape,
-    )
-
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    weight = get_trt_tensor(ctx, weight, f"{name}_weight")
-    bias = get_trt_tensor(ctx, bias, f"{name}_bias")
-    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
-
-    dims = list(range(1, len(input.shape)))
-
-    # E[X]
-    mean_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean",
-        reshaped_input,
-        dims,
-        True,
-    )
+) -> Tuple[TRTTensor, torch.Tensor, torch.Tensor]:
+    rank = len(input.shape)

-    mean_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_mean_trt",
-        mean_trt,
-        reshaped_input.shape,
-    )
+    assert rank >= 3, f"Expected at least 3 dimensions for input tensor but got {rank}"

-    # X - E[X]
-    sub_trt = impl.elementwise.sub(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sub",
-        reshaped_input,
-        mean_trt,
-    )
+    assert (
+        C == input.shape[1]
+    ), f"num_channels ({C}) must be equal to number of channels in input ({input.shape[1]})"

-    # variance
-    pow_trt = get_trt_tensor(ctx, 2, f"{name}_power", np.float32)
-    pow_var = impl.elementwise.pow(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_pow",
-        sub_trt,
-        pow_trt,
-    )
+    weight_one = get_trt_tensor(ctx, 1.0, f"{name}_weight_one", input.dtype)
+    bias_zero = get_trt_tensor(ctx, 0.0, f"{name}_bias_zero", input.dtype)

-    var_trt = impl.reduce.mean(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mean_var",
-        pow_var,
-        dims,
-        True,
-    )
+    shape = [1, group] + [1] * (rank - 2)

-    var_trt = impl.slice.expand(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_expand_var_trt",
-        var_trt,
-        reshaped_input.shape,
+    weight_one = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_weight_one", weight_one, shape
    )
-
-    eps_trt = get_trt_tensor(ctx, eps, f"{name}_eps", np.float32)
-    add_trt = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add",
-        var_trt,
-        eps_trt,
+    bias_zero = impl.slice.expand(
+        ctx, target, source_ir, f"{name}_expand_bias_zero", bias_zero, shape
    )

-    sqrt_trt = impl.unary.sqrt(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_sqrt",
-        add_trt,
-    )
+    axes = get_axes_for_reduce_op([i for i in range(1 if group == 1 else 2, rank)])

-    # y = (X - E[X]) / sqrt((var + eps))
-    output = impl.elementwise.div(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_div",
-        sub_trt,
-        sqrt_trt,
-    )
+    # INormalizationLayer scales the normalized output per-group, whereas PyTorch scales it per-channel,
+    # which would make the results diverge. Have TensorRT apply a no-op scale here; the per-channel scale is applied below.
+    layer = ctx.net.add_normalization(input, weight_one, bias_zero, axes)
+    layer.epsilon = eps
+    layer.num_groups = group
+    set_layer_name(layer, target, name, source_ir)
+    output = layer.get_output(0)

-    shape = list(output.shape)
-    for i, s in enumerate(shape):
-        if i == 0 and s > 0:
-            shape[i] = B
-        elif i == 1:
-            shape[i] = C
-        elif i > 1 and s == -1:
-            shape[i] = 0
+    shape[1] = C

-    reshaped_output = impl.shuffle.reshape(
-        ctx, target, source_ir, f"{name}_reshape_output", output, shape
-    )
-    reshaped_gamma = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_gamma",
-        weight,
-        weight_bias_shape,
-    )
-
-    reshaped_output = impl.elementwise.mul(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_mul_gamma",
-        reshaped_output,
-        reshaped_gamma,
-    )
-
-    reshaped_bias = impl.shuffle.reshape(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_reshape_beta",
-        bias,
-        weight_bias_shape,
-    )
-    reshaped_output = impl.elementwise.add(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_add_beta",
-        reshaped_output,
-        reshaped_bias,
-    )
-    if return_mean_rstd:
-        # return fake mean and rstd for now
-        return reshaped_output, None, None
-    return reshaped_output
+    if weight is not None:
+        weight = get_trt_tensor(ctx, weight, f"{name}_weight")
+        weight = cast_trt_tensor(
+            ctx, weight, input.dtype, f"{name}_cast_weight", target, source_ir
+        )
+        weight = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_weight", weight, shape
+        )
+        output = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_mul_weight", output, weight
+        )

+    if bias is not None:
+        bias = get_trt_tensor(ctx, bias, f"{name}_bias")
+        bias = cast_trt_tensor(
+            ctx, bias, input.dtype, f"{name}_cast_bias", target, source_ir
+        )
+        bias = impl.shuffle.reshape(
+            ctx, target, source_ir, f"{name}_reshape_bias", bias, shape
+        )
+        output = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add_bias", output, bias
+        )

-def group_norm(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    num_groups: int,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-    cudnn_enabled: bool,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return native_group_norm(
-        ctx,
-        target,
-        source_ir,
-        name,
-        input,
-        weight,
-        bias,
-        0,
-        0,
-        0,
-        num_groups,
-        eps,
-        return_mean_rstd=False,
-    )
+    # return fake mean and rstd for now
+    return output, None, None


def softmax(
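For reference on the scaling comment in the new code path: group normalization with a learned affine is equivalent to normalizing with unit scale and zero shift and then applying the per-channel weight and bias explicitly, which is the split the converter performs around INormalizationLayer. A minimal PyTorch sketch of that equivalence (the shapes and values below are illustrative only):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch 2, 8 channels split into 4 groups, 4x4 spatial.
x = torch.randn(2, 8, 4, 4)
weight = torch.randn(8)  # per-channel gamma
bias = torch.randn(8)    # per-channel beta
groups, eps = 4, 1e-5

# Reference path: PyTorch fuses the per-channel affine into group_norm.
ref = F.group_norm(x, groups, weight, bias, eps)

# Decomposed path mirroring the converter: normalize with unit scale and
# zero shift, then apply the per-channel weight and bias explicitly.
normalized = F.group_norm(x, groups, eps=eps)
out = normalized * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

assert torch.allclose(ref, out, atol=1e-5)
```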
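The `axes` value passed to `add_normalization` is a bit mask with one bit set per dimension the normalization reduces over; `get_axes_for_reduce_op` is assumed to build such a mask from the dimension list. A standalone sketch using a hypothetical helper name, not the library function itself:

```python
from typing import Sequence


def reduce_axes_bitmask(dims: Sequence[int]) -> int:
    """Set bit i for every dimension i that should be reduced over."""
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask


# For a 4-D NCHW input with more than one group, the converter reduces over
# the spatial dimensions only; the channel/group split is expressed through
# the layer's num_groups attribute.
rank, group = 4, 2
dims = list(range(1 if group == 1 else 2, rank))  # -> [2, 3]
print(reduce_axes_bitmask(dims))  # 12 == 0b1100 (bits 2 and 3 set)
```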