Skip to content

Commit ba8652e

Browse files
committed
chore : minor fix
1 parent 9133df8 commit ba8652e

File tree

1 file changed

+6
-6
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/normalization

1 file changed

+6
-6
lines changed

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def batch_norm(
3636
input: TRTTensor,
3737
weight: Optional[Union[torch.Tensor, np.ndarray]],
3838
bias: Optional[Union[torch.Tensor, np.ndarray]],
39-
running_mean: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
40-
running_var: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
39+
running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
40+
running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
4141
training: bool,
4242
momentum: float,
4343
eps: float,
@@ -55,13 +55,13 @@ def batch_norm(
5555
if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
5656
# Default values if weight, bias, running_mean, running_var are None
5757
if weight is None:
58-
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight", input.dtype)
58+
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
5959
if bias is None:
60-
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias", input.dtype)
60+
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
6161
if running_mean is None:
62-
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean", input.dtype)
62+
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
6363
if running_var is None:
64-
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var", input.dtype)
64+
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
6565

6666
# eps_tensor for numerical stability
6767
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

0 commit comments

Comments
 (0)