Skip to content

Commit d978db6

Browse files
committed
fix bugs
1 parent dce9526 commit d978db6

File tree

2 files changed

+51
-25
lines changed

2 files changed

+51
-25
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def args_bounds_check(
1818
return args[i] if len(args) > i else replacement
1919

2020

21+
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default) # type: ignore[misc]
2122
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc]
2223
def aten_ops_batch_norm(
2324
network: TRTNetwork,
@@ -32,17 +33,18 @@ def aten_ops_batch_norm(
3233
SourceIR.ATEN,
3334
name,
3435
input=args[0],
35-
weight=args_bounds_check(args, 1, replacement=1),
36-
bias=args_bounds_check(args, 2, replacement=0),
37-
running_mean=args_bounds_check(args, 3),
38-
running_var=args_bounds_check(args, 4),
39-
training=args_bounds_check(args, 5),
40-
momentum=args_bounds_check(args, 6, replacement=0.1),
41-
eps=args_bounds_check(args, 7, replacement=1e-05),
42-
cudnn_enabled=args_bounds_check(args, 8, replacement=False),
36+
weight=args[1],
37+
bias=args[2],
38+
running_mean=args[3],
39+
running_var=args[4],
40+
training=args[5],
41+
momentum=args[6],
42+
eps=args[7],
43+
cudnn_enabled=args_bounds_check(args, 8, replacement=True),
4344
)
4445

4546

47+
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc]
4648
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
4749
def aten_ops_layer_norm(
4850
network: TRTNetwork,
@@ -58,9 +60,9 @@ def aten_ops_layer_norm(
5860
name,
5961
input=args[0],
6062
normalized_shape=args[1],
61-
weight=args_bounds_check(args, 2, replacement=1),
62-
bias=args_bounds_check(args, 3, replacement=0),
63-
eps=args_bounds_check(args, 4, replacement=1e-05),
63+
weight=args[2],
64+
bias=args[3],
65+
eps=args[4],
6466
cudnn_enable=args_bounds_check(args, 5, replacement=True),
6567
)
6668

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def batch_norm(
2929
source_ir: Optional[SourceIR],
3030
name: str,
3131
input: TRTTensor,
32-
weight: torch.Tensor,
33-
bias: torch.Tensor,
34-
running_mean: torch.Tensor,
35-
running_var: torch.Tensor,
36-
training: torch.Tensor,
37-
momentum: torch.Tensor,
38-
eps: List[float],
32+
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
33+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
34+
running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
35+
running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
36+
training: bool,
37+
momentum: float,
38+
eps: float,
3939
cudnn_enabled: bool,
4040
) -> Union[TRTTensor, Sequence[TRTTensor]]:
4141
if not isinstance(input, TRTTensor):
@@ -47,8 +47,20 @@ def batch_norm(
4747
if has_dynamic_shape(input.shape):
4848
assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
4949

50+
if weight is None:
51+
weight = np.array(1.0)
52+
53+
if bias is None:
54+
bias = np.array(0.0)
55+
56+
if running_mean is None:
57+
running_mean = np.array(0.0)
58+
59+
if running_var is None:
60+
running_var = np.array(1.0)
61+
5062
scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
51-
cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
63+
cast(torch.Tensor, to_numpy(running_var)) + eps
5264
)
5365

5466
bias = to_numpy(bias) - to_numpy(running_mean) * scale
@@ -91,9 +103,9 @@ def layer_norm(
91103
name: str,
92104
input: TRTTensor,
93105
normalized_shape: List[int],
94-
weight: torch.Tensor,
95-
bias: torch.Tensor,
96-
eps: List[float],
106+
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
107+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
108+
eps: float,
97109
cudnn_enable: bool,
98110
) -> Union[TRTTensor, Sequence[TRTTensor]]:
99111
if not isinstance(input, trt.tensorrt.ITensor):
@@ -102,6 +114,12 @@ def layer_norm(
102114
"of the TensorRT region!"
103115
)
104116

117+
if weight is None:
118+
weight = np.array(1.0)
119+
120+
if bias is None:
121+
bias = np.array(0.0)
122+
105123
gamma = (
106124
weight.detach().cpu().float().numpy()
107125
if isinstance(weight, torch.Tensor)
@@ -152,16 +170,22 @@ def layer_norm_no_plugin(
152170
name: str,
153171
input: TRTTensor,
154172
normalized_shape: List[int],
155-
weight: torch.Tensor,
156-
bias: torch.Tensor,
157-
eps: List[float],
173+
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
174+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
175+
eps: float,
158176
) -> Union[TRTTensor, Sequence[TRTTensor]]:
159177
if not isinstance(input, TRTTensor):
160178
raise RuntimeError(
161179
f"LayerNorm received input {input} that is not part "
162180
"of the TensorRT region!"
163181
)
164182

183+
if weight is None:
184+
weight = np.array(1.0)
185+
186+
if bias is None:
187+
bias = np.array(0.0)
188+
165189
shape = weight.shape
166190
broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
167191
gamma = to_numpy(weight.reshape(*shape))

0 commit comments

Comments (0)