Skip to content

Commit 9dc5e5d

Browse files
committed
fix type bug
1 parent d978db6 commit 9dc5e5d

File tree

1 file changed

+8
-8
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/normalization

1 file changed

+8
-8
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,16 @@ def batch_norm(
4848
assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
4949

5050
if weight is None:
51-
weight = np.array(1.0)
51+
weight = 1.0
5252

5353
if bias is None:
54-
bias = np.array(0.0)
54+
bias = 0.0
5555

5656
if running_mean is None:
57-
running_mean = np.array(0.0)
57+
running_mean = 0.0
5858

5959
if running_var is None:
60-
running_var = np.array(1.0)
60+
running_var = 1.0
6161

6262
scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
6363
cast(torch.Tensor, to_numpy(running_var)) + eps
@@ -115,10 +115,10 @@ def layer_norm(
115115
)
116116

117117
if weight is None:
118-
weight = np.array(1.0)
118+
weight = to_numpy(1.0)
119119

120120
if bias is None:
121-
bias = np.array(0.0)
121+
bias = to_numpy(0.0)
122122

123123
gamma = (
124124
weight.detach().cpu().float().numpy()
@@ -181,10 +181,10 @@ def layer_norm_no_plugin(
181181
)
182182

183183
if weight is None:
184-
weight = np.array(1.0)
184+
weight = to_numpy(1.0)
185185

186186
if bias is None:
187-
bias = np.array(0.0)
187+
bias = to_numpy(0.0)
188188

189189
shape = weight.shape
190190
broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape

0 commit comments

Comments
 (0)