Skip to content

Commit 414d972

Browse files
committed
Fixed batchnorm bug
1 parent e4e4d31 commit 414d972

File tree

2 files changed

+31
-12
lines changed

2 files changed

+31
-12
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919

2020
import numpy as np
21+
import tensorrt as trt
2122
import torch
2223
import torch.fx
2324
from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
4344
from torch_tensorrt.fx.observer import Observer
4445
from torch_tensorrt.logging import TRT_LOGGER
4546

46-
import tensorrt as trt
4747
from packaging import version
4848

4949
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -472,12 +472,18 @@ def _save_weight_mapping(self) -> None:
472472
# Retrieve each weight name(s) in state_dict
473473
if layer_type == "CONSTANT":
474474
if "embedding" in suffix:
475-
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
475+
sd_weight_name = f"{sd_weight_name}.weight"
476476
elif "weight" in suffix or "mm_other" in suffix:
477477
# Linear layer weight
478-
sd_weight_name = f"{sd_weight_name}.{torch_attr[0]}"
478+
sd_weight_name = f"{sd_weight_name}.weight"
479+
elif "running_mean" in suffix:
480+
# Batch norm running mean statistic
481+
sd_weight_name = f"{sd_weight_name}.running_mean"
482+
elif "running_var" in suffix:
483+
# Batch norm running variance statistic
484+
sd_weight_name = f"{sd_weight_name}.running_var"
479485
else:
480-
sd_weight_name = f"{sd_weight_name}.{torch_attr[1]}"
486+
sd_weight_name = f"{sd_weight_name}.bias"
481487
elif layer_type == "SCALE":
482488
# Batch norm needs all weights to calculate scale and shift
483489
sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr]

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,27 @@ def batch_norm(
5050
# Save the original output shape for later use
5151
output_shape = input.shape
5252

53-
if weight is None:
54-
weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
55-
if bias is None:
56-
bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
57-
if running_mean is None:
58-
running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
59-
if running_var is None:
60-
running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
53+
# We name the weight here according to the state_dict name
54+
weight = (
55+
get_trt_tensor(ctx, 1.0, f"{name}_weight")
56+
if weight is None
57+
else get_trt_tensor(ctx, weight, f"{name}_weight")
58+
)
59+
bias = (
60+
get_trt_tensor(ctx, 0.0, f"{name}_bias")
61+
if bias is None
62+
else get_trt_tensor(ctx, bias, f"{name}_bias")
63+
)
64+
running_mean = (
65+
get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
66+
if running_mean is None
67+
else get_trt_tensor(ctx, running_mean, f"{name}_running_mean")
68+
)
69+
running_var = (
70+
get_trt_tensor(ctx, 1.0, f"{name}_running_var")
71+
if running_var is None
72+
else get_trt_tensor(ctx, running_var, f"{name}_running_var")
73+
)
6174

6275
# eps_tensor for numerical stability
6376
eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

0 commit comments

Comments
 (0)