
Commit 8a41cf9

update group_norm with native ops
1 parent a4634f7 commit 8a41cf9

File tree

1 file changed: +162 -104 lines changed

  • py/torch_tensorrt/dynamo/conversion/impl/normalization

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 162 additions & 104 deletions
@@ -10,8 +10,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import get_axes_for_reduce_op
 from torch_tensorrt.fx.converters.converter_utils import (
     get_positive_dim,
-    get_trt_plugin,
-    get_trt_tensor,
     has_dynamic_shape,
     set_layer_name,
     to_numpy,
@@ -58,10 +56,7 @@ def batch_norm(
     if running_var is None:
         running_var = 1.0
 
-    scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-        cast(torch.Tensor, to_numpy(running_var)) + eps
-    )
-
+    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
     bias = to_numpy(bias) - to_numpy(running_mean) * scale
     power = np.ones_like(scale)
 
@@ -107,78 +102,6 @@ def layer_norm(
     eps: float,
     cudnn_enable: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
-    if weight is None:
-        weight = to_numpy(1.0)
-
-    if bias is None:
-        bias = to_numpy(0.0)
-
-    gamma = (
-        weight.detach().cpu().float().numpy()
-        if isinstance(weight, torch.Tensor)
-        else weight
-    )
-    gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32)
-    beta = (
-        bias.detach().cpu().float().numpy() if isinstance(bias, torch.Tensor) else bias
-    )
-    beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32)
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
-    )
-    try:
-        normalized_shape_arr = np.array(normalized_shape, dtype=np.int32)
-    except TypeError:
-        _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []")
-        normalized_shape_arr = np.array([], dtype=np.int32)
-
-    normalized_shape_filed = trt.PluginField(
-        "normalized_shape", normalized_shape_arr, trt.PluginFieldType.INT32
-    )
-    field_collection = trt.PluginFieldCollection(
-        [gamma_field, beta_field, eps_field, normalized_shape_filed]
-    )
-
-    try:
-        if network.has_implicit_batch_dimension:
-            plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt")
-        else:
-            plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find layer norm plugin, fall back to TensorRT implementation."
-        )
-        return layer_norm_no_plugin(
-            network, target, source_ir, name, input, normalized_shape, weight, bias, eps
-        )
-    layer = network.add_plugin_v2([input], plugin)
-    layer.name = name
-    return layer.get_output(0)
-
-
-def layer_norm_no_plugin(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    normalized_shape: List[int],
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    eps: float,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)
 
@@ -333,45 +256,180 @@ def group_norm(
     eps: float,
     cudnn_enabled: bool,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    if not isinstance(input, trt.tensorrt.ITensor):
-        raise RuntimeError(
-            f"LayerNorm received input {input} that is not part "
-            "of the TensorRT region!"
-        )
-
     if weight is None:
         weight = to_numpy(1.0)
 
     if bias is None:
         bias = to_numpy(0.0)
 
-    scale = get_trt_tensor(network, weight, "scale")
-    bias = get_trt_tensor(network, bias, "bias")
+    assert (
+        len(input.shape) >= 3
+    ), f"The input dimension should not be less than 3, got {len(input.shape)}!"
+    B, C = input.shape[0], input.shape[1]
 
-    eps_field = trt.PluginField(
-        "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
+    # Groups are a subdivision of the channel dimension.
+    assert (
+        C % num_groups == 0
+    ), f"The num of channels ({C}) should be divisible by num_groups ({num_groups})!"
+
+    # Normalize every group.
+    reshaped_input = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        input,
+        shape=(B * num_groups, -1),
     )
-    num_groups_filed = trt.PluginField(
-        "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
+    dim = (
+        len(reshaped_input.shape) - 1
+    )  # TODO: PR #2347 supported negtive dimension in reduce, could be -1
+
+    # E[X]
+    mean_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean",
+        reshaped_input,
+        dim=dim,
+        keepdim=True,
    )
 
-    field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])
+    # X - E[X]
+    sub_trt = impl.elementwise.sub(
+        network,
+        target,
+        source_ir,
+        f"{name}_sub",
+        reshaped_input,
+        mean_trt,
+    )
 
-    try:
-        # Here's the schema of the plugin:
-        # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
-        plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
-    except AssertionError:
-        _LOGGER.error(
-            "Unable to find group norm plugin, fall back to TensorRT implementation."
-        )
+    # variance = mean(pow(sub_trt, 2))
+    pow_layer = network.add_constant(
+        (1,) * len(sub_trt.shape),
+        trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)),
+    )
+    pow_layer.name = f"{name}_power"
 
-    layer = network.add_plugin_v2([input, scale, bias], plugin)
-    set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)
+    pow_var = impl.elementwise.pow(
+        network,
+        target,
+        source_ir,
+        f"{name}_pow",
+        sub_trt,
+        pow_layer.get_output(0),
+    )
+
+    var_trt = impl.reduce.mean(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mean_var",
+        pow_var,
+        dim=dim,
+        keepdim=True,
+    )
+
+    # sqrt((var + eps))
+    eps_layer = network.add_constant(
+        (1,) * len(reshaped_input.shape),
+        trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)),
+    )
+    eps_layer.name = f"{name}_eps"
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add",
+        var_trt,
+        eps_layer.get_output(0),
+    )
+    sqrt_trt = impl.unary.sqrt(
+        network,
+        target,
+        source_ir,
+        f"{name}_sqrt",
+        add_trt,
+    )
+
+    # (X - E[X]) / sqrt((var + eps))
+    div_trt = impl.elementwise.div(
+        network,
+        target,
+        source_ir,
+        f"{name}_div",
+        sub_trt,
+        sqrt_trt,
+    )
+
+    # Apply per-channel scale and bias.
+    output = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_div",
+        div_trt,
+        shape=input.shape,
+    )
+
+    weight_bias_shape = (1, C) + (1,) * (len(input.shape) - 2)
+
+    reshaped_weight = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_weight",
+        weight,
+        shape=weight_bias_shape,
+    )
+
+    output = impl.elementwise.mul(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_mul_scale",
+        output,
+        reshaped_weight,
+    )
+
+    reshaped_bias = impl.shuffle.reshape(
+        network,
+        target,
+        SourceIR.ATEN,
+        f"{name}_reshape_bias",
+        bias,
+        shape=weight_bias_shape,
+    )
+
+    add_trt = impl.elementwise.add(
+        network,
+        target,
+        source_ir,
+        f"{name}_add_bias",
+        output,
+        reshaped_bias,
+    )
 
-    # PyTorch requires three return values: (out, mean, rstd)
-    dummy_tensor = torch.tensor(0)
-    return layer.get_output(0), dummy_tensor, dummy_tensor
+    # TODO: compute the last two return values
+    # const1_layer = network.add_constant(
+    #     (1,) * len(sqrt_trt.shape),
+    #     trt.Weights(np.ascontiguousarray([1.0], dtype=np.float32)),
+    # )
+    # const1_layer.name = f"{name}_const1"
+
+    # rsqrt_trt = impl.elementwise.div(
+    #     network,
+    #     target,
+    #     source_ir,
+    #     f"{name}_rsqrt",
+    #     const1_layer.get_output(0),
+    #     sqrt_trt,
+    # )
+
+    return add_trt, torch.tensor(0), torch.tensor(0)
 
 
 def softmax(
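
For reference, the rewritten group_norm converter expresses GroupNorm with native TensorRT layers only: reshape the input to (B * num_groups, -1), compute the mean and the (biased) variance of each group, normalize as (X - E[X]) / sqrt(Var[X] + eps), reshape back, and apply the per-channel weight and bias. Below is a minimal eager-mode sketch of the same arithmetic, checked against torch.nn.functional.group_norm; the helper name group_norm_reference is illustrative only and not part of this commit.

import torch
import torch.nn.functional as F


def group_norm_reference(x, num_groups, weight, bias, eps=1e-5):
    # Eager-mode mirror of the converter's decomposition (assumed sketch).
    B, C = x.shape[0], x.shape[1]
    assert C % num_groups == 0

    # Flatten each (batch, group) slice, like the converter's first reshape.
    reshaped = x.reshape(B * num_groups, -1)
    mean = reshaped.mean(dim=-1, keepdim=True)                  # E[X]
    var = ((reshaped - mean) ** 2).mean(dim=-1, keepdim=True)   # mean(pow(X - E[X], 2))
    normalized = (reshaped - mean) / torch.sqrt(var + eps)      # (X - E[X]) / sqrt(var + eps)

    # Restore the original shape and apply the per-channel affine parameters,
    # broadcast through a (1, C, 1, ..., 1) view of weight and bias.
    out = normalized.reshape(x.shape)
    affine_shape = (1, C) + (1,) * (x.dim() - 2)
    return out * weight.reshape(affine_shape) + bias.reshape(affine_shape)


x = torch.randn(2, 6, 4, 4)
weight, bias = torch.randn(6), torch.randn(6)
expected = F.group_norm(x, 3, weight, bias, eps=1e-5)
actual = group_norm_reference(x, 3, weight, bias, eps=1e-5)
print(torch.allclose(actual, expected, atol=1e-5))  # expected: True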
