Commit 2da25ad

update group_norm with native ops
1 parent 794e2bb commit 2da25ad

File tree

1 file changed: +163 -104 lines changed

  • py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 163 additions & 104 deletions
@@ -13,8 +13,7 @@
 )
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
 from torch_tensorrt.fx.converters.converter_utils import (
-    get_trt_plugin,
-    get_trt_tensor,
+    get_positive_dim,
     has_dynamic_shape,
     set_layer_name,
 )
@@ -53,10 +52,7 @@ def native_batch_norm(
     if running_var is None:
         running_var = 1.0
 
-    scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-        cast(torch.Tensor, to_numpy(running_var)) + eps
-    )
-
+    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
     bias = to_numpy(bias) - to_numpy(running_mean) * scale
     power = np.ones_like(scale)
 
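The simplified scale above folds inference-mode batch norm into a single per-channel scale and bias (the power term stays all ones in the surrounding context). Below is a minimal standalone NumPy sketch of that identity with toy shapes and values; none of these names come from this file.

import numpy as np

# Toy inputs purely for illustration (shape (N, C, H, W) with C = 3).
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)
weight, bias = rng.standard_normal(3), rng.standard_normal(3)
running_mean, running_var = rng.standard_normal(3), rng.random(3) + 0.5
eps = 1e-5

# The fold computed in the hunk above: y = scale * x + folded_bias.
scale = weight / np.sqrt(running_var + eps)
folded_bias = bias - running_mean * scale

# Reference inference-mode batch norm, broadcast over the channel axis.
c = (1, 3, 1, 1)
reference = (
    weight.reshape(c) * (x - running_mean.reshape(c)) / np.sqrt(running_var.reshape(c) + eps)
    + bias.reshape(c)
)
folded = scale.reshape(c) * x + folded_bias.reshape(c)
assert np.allclose(reference, folded, atol=1e-5)
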
@@ -135,78 +131,6 @@ def layer_norm(
     eps: float,
     cudnn_enable: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    gamma = (
-        weight.detach().cpu().float().numpy()
-        if isinstance(weight, torch.Tensor)
-        else weight
-    )
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = (
-        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
-    )
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
-    try:
-        normalized_shape_arr = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
-        normalized_shape_arr = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
-
-    try:
-        if ctx.net.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm_no_plugin(
-            ctx, target, source_ir, name, input, normalized_shape, weight, bias, eps
-        )
-    layer = ctx.net.add_plugin_v2([input], plugin)
-    layer.name = name
-    return layer.get_output(0)
-
-
-def layer_norm_no_plugin(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)
 
@@ -357,45 +281,180 @@ def group_norm(
     eps: float,
     cudnn_enabled: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)
 
     if bias is None:
         bias = to_numpy(0.0)
 
-    scale = get_trt_tensor(network, weight, "scale")
-    bias = get_trt_tensor(network, bias, "bias")
+    assert (
+        len(input.shape) >= 3
+    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    B, C = input.shape[0], input.shape[1]
 
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+    # Groups are a subdivision of the channel dimension.
+    assert (
+        C % num_groups == 0
+    ), f"The num of channels ({C}) should be divisible by num_groups ({num_groups})!"
+
+    # Normalize every group.
+    reshaped_input = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input,
+        shape=(B * num_groups, -1),
     )
-    num_groups_filed = trt.PluginField(
-        "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
+    dim = (
+        len(reshaped_input.shape) - 1
+    )  # TODO: PR #2347 supported negtive dimension in reduce, could be -1
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean",
+        reshaped_input,
+        dim=dim,
+        keepdim=True,
     )
 
-    field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        network,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
 
-    try:
-        # Here's the schema of the plugin:
-        # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
-        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find group norm plugin, fall back to TensorRT implementation."
-        )
+    # variance = mean(pow(sub_trt, 2))
+    pow_layer = network.add_constant(
+        (1,) * len(sub_trt.shape),
+        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
+    )
+    pow_layer.name = f"{name}_power"
+
+    pow_var = impl.elementwise.pow(
+        network,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_layer.get_output(0),
+    )
+
+    var_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean_var",
+        pow_var,
+        dim=dim,
+        keepdim=True,
+    )
+
+    # sqrt((var + eps))
+    eps_layer = network.add_constant(
+        (1,) * len(reshaped_input.shape),
+        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
+    )
+    eps_layer.name = f"{name}_eps"
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_layer.get_output(0),
+    )
+    sqrt_trt = impl.unary.sqrt(
+        network,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
+        network,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
+    )
+
+    # Apply per-channel scale and bias.
+    output = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_div",
+        div_trt,
+        shape=input.shape,
+    )
+
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+    reshaped_weight = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_weight",
+        weight,
+        shape=weight_bias_shape,
+    )
+
+    output = impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mul_scale",
+        output,
+        reshaped_weight,
+    )
 
-    layer = network.add_plugin_v2([input, scale, bias], plugin)
-    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
+    reshaped_bias = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_bias",
+        bias,
+        shape=weight_bias_shape,
+    )
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add_bias",
+        output,
+        reshaped_bias,
+    )
+
+    # TODO: compute the last two return values
+    # const1_layer = network.add_constant(
+    #     (1,) * len(sqrt_trt.shape),
+    #     trt.Weights(np.ascontiguousarray([1.0], dtype=np.float32)),
+    # )
+    # const1_layer.name = f"{name}_const1"
+
+    # rsqrt_trt = impl.elementwise.div(
+    #     network,
+    #     target,
+    #     source_ir,
+    #     f"{name}_rsqrt",
+    #     const1_layer.get_output(0),
+    #     sqrt_trt,
+    # )
 
-    # PyTorch requires three return values: (out, mean, rstd)
-    dummy_tensor = torch.tensor(0)
-    return layer.get_output(0), dummy_tensor, dummy_tensor
+    return add_trt, torch.tensor(0), torch.tensor(0)
 
 
 def softmax(
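For reference, the layer graph assembled in the new group_norm converter follows the standard group-norm decomposition: flatten each (sample, group) slice, normalize it by its mean and biased variance, reshape back to the input shape, then apply the per-channel weight and bias. Below is a minimal standalone PyTorch sketch of that math checked against torch.nn.functional.group_norm; group_norm_reference and its toy inputs are illustrative names, not part of the converter.

import torch
import torch.nn.functional as F

def group_norm_reference(x, num_groups, weight, bias, eps=1e-5):
    # Mirrors the reshape -> mean -> sub -> pow -> mean -> sqrt -> div chain above.
    B, C = x.shape[0], x.shape[1]
    assert C % num_groups == 0
    grouped = x.reshape(B * num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)                 # E[X]
    var = ((grouped - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance
    normalized = (grouped - mean) / torch.sqrt(var + eps)     # (X - E[X]) / sqrt(var + eps)
    out = normalized.reshape(x.shape)
    affine_shape = (1, C) + (1,) * (x.dim() - 2)               # weight_bias_shape in the diff
    return out * weight.reshape(affine_shape) + bias.reshape(affine_shape)

x = torch.randn(2, 6, 4, 4)
w, b = torch.randn(6), torch.randn(6)
expected = F.group_norm(x, num_groups=3, weight=w, bias=b, eps=1e-5)
assert torch.allclose(group_norm_reference(x, 3, w, b), expected, atol=1e-5)

As the TODO in the diff notes, the converter currently returns placeholder tensors for the mean and rstd outputs rather than the per-group statistics computed in this sketch.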
