Skip to content

Commit cc967cf

Browse files
committed
fix bugs
1 parent 903802b commit cc967cf

File tree

2 files changed

+51
-25
lines changed

2 files changed

+51
-25
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def get_ir(target: Target) -> SourceIR:
4646
return SourceIR.UNKNOWN
4747

4848

49+
@dynamo_tensorrt_converter(torch.ops.aten.native_batch_norm.default) # type: ignore[misc]
4950
@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) # type: ignore[misc]
5051
def aten_ops_batch_norm(
5152
ctx: ConversionContext,
@@ -60,17 +61,18 @@ def aten_ops_batch_norm(
6061
SourceIR.ATEN,
6162
name,
6263
input=args[0],
63-
weight=args_bounds_check(args, 1, replacement=1),
64-
bias=args_bounds_check(args, 2, replacement=0),
65-
running_mean=args_bounds_check(args, 3),
66-
running_var=args_bounds_check(args, 4),
67-
training=args_bounds_check(args, 5),
68-
momentum=args_bounds_check(args, 6, replacement=0.1),
69-
eps=args_bounds_check(args, 7, replacement=1e-05),
70-
cudnn_enabled=args_bounds_check(args, 8, replacement=False),
64+
weight=args[1],
65+
bias=args[2],
66+
running_mean=args[3],
67+
running_var=args[4],
68+
training=args[5],
69+
momentum=args[6],
70+
eps=args[7],
71+
cudnn_enabled=args_bounds_check(args, 8, replacement=True),
7172
)
7273

7374

75+
@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default) # type: ignore[misc]
7476
@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) # type: ignore[misc]
7577
def aten_ops_layer_norm(
7678
ctx: ConversionContext,
@@ -86,9 +88,9 @@ def aten_ops_layer_norm(
8688
name,
8789
input=args[0],
8890
normalized_shape=args[1],
89-
weight=args_bounds_check(args, 2, replacement=1),
90-
bias=args_bounds_check(args, 3, replacement=0),
91-
eps=args_bounds_check(args, 4, replacement=1e-05),
91+
weight=args[2],
92+
bias=args[3],
93+
eps=args[4],
9294
cudnn_enable=args_bounds_check(args, 5, replacement=True),
9395
)
9496

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,13 @@ def batch_norm(
2929
source_ir: Optional[SourceIR],
3030
name: str,
3131
input: TRTTensor,
32-
weight: torch.Tensor,
33-
bias: torch.Tensor,
34-
running_mean: torch.Tensor,
35-
running_var: torch.Tensor,
36-
training: torch.Tensor,
37-
momentum: torch.Tensor,
38-
eps: List[float],
32+
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
33+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
34+
running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
35+
running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
36+
training: bool,
37+
momentum: float,
38+
eps: float,
3939
cudnn_enabled: bool,
4040
) -> Union[TRTTensor, Sequence[TRTTensor]]:
4141
if not isinstance(input, TRTTensor):
@@ -47,8 +47,20 @@ def batch_norm(
4747
if has_dynamic_shape(input.shape):
4848
assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
4949

50+
if weight is None:
51+
weight = np.array(1.0)
52+
53+
if bias is None:
54+
bias = np.array(0.0)
55+
56+
if running_mean is None:
57+
running_mean = np.array(0.0)
58+
59+
if running_var is None:
60+
running_var = np.array(1.0)
61+
5062
scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
51-
cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
63+
cast(torch.Tensor, to_numpy(running_var)) + eps
5264
)
5365

5466
bias = to_numpy(bias) - to_numpy(running_mean) * scale
@@ -91,9 +103,9 @@ def layer_norm(
91103
name: str,
92104
input: TRTTensor,
93105
normalized_shape: List[int],
94-
weight: torch.Tensor,
95-
bias: torch.Tensor,
96-
eps: List[float],
106+
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
107+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
108+
eps: float,
97109
cudnn_enable: bool,
98110
) -> Union[TRTTensor, Sequence[TRTTensor]]:
99111
if not isinstance(input, trt.tensorrt.ITensor):
@@ -102,6 +114,12 @@ def layer_norm(
102114
"of the TensorRT region!"
103115
)
104116

117+
if weight is None:
118+
weight = np.array(1.0)
119+
120+
if bias is None:
121+
bias = np.array(0.0)
122+
105123
gamma = (
106124
weight.detach().cpu().float().numpy()
107125
if isinstance(weight, torch.Tensor)
@@ -152,16 +170,22 @@ def layer_norm_no_plugin(
152170
name: str,
153171
input: TRTTensor,
154172
normalized_shape: List[int],
155-
weight: torch.Tensor,
156-
bias: torch.Tensor,
157-
eps: List[float],
173+
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
174+
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
175+
eps: float,
158176
) -> Union[TRTTensor, Sequence[TRTTensor]]:
159177
if not isinstance(input, TRTTensor):
160178
raise RuntimeError(
161179
f"LayerNorm received input {input} that is not part "
162180
"of the TensorRT region!"
163181
)
164182

183+
if weight is None:
184+
weight = np.array(1.0)
185+
186+
if bias is None:
187+
bias = np.array(0.0)
188+
165189
shape = weight.shape
166190
broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape
167191
gamma = to_numpy(weight.reshape(*shape))

0 commit comments

Comments (0)