
Commit 66511da

feat: Add handling for ITensor mean and var in batch_norm (#3099)
1 parent 9a08cc7 commit 66511da

2 files changed (+174, -41 lines)

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 75 additions & 35 deletions
@@ -21,9 +21,8 @@
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -34,61 +33,102 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     training: bool,
     momentum: float,
     eps: float,
     cudnn_enabled: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."
 
-    if weight is None:
-        weight = 1.0
+    # Save the original output shape for later use
+    output_shape = input.shape
 
+    if weight is None:
+        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
     if bias is None:
-        bias = 0.0
-
+        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
     if running_mean is None:
-        running_mean = 0.0
-
+        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
     if running_var is None:
-        running_var = 1.0
+        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
 
-    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-    bias = to_numpy(bias) - to_numpy(running_mean) * scale
-    power = np.ones_like(scale)
+    # eps_tensor for numerical stability
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
 
-    # For BatchNorm1d, reshape 1d to 2d
-    output_shape = input.shape
-    if len(input.shape) < 4:
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        new_shape = (
-            (input.shape[0], input.shape[1], 1, 1)
-            if len(input.shape) == 2
-            else (input.shape[0], input.shape[1], input.shape[2], 1)
-        )
-        input = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-        )
-    layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name, source_ir)
-    output = layer.get_output(0)
+    # adjusted_var = running_var + eps
+    adjusted_var = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+    )
+
+    # sqrt_adjusted_var = sqrt(adjusted_var)
+    sqrt_adjusted_var = impl.unary.sqrt(
+        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+    )
+
+    # scale = weight / sqrt_adjusted_var
+    scale = impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+    )
+
+    # scaled_running_mean = running_mean * scale
+    scaled_running_mean = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+    )
+
+    # bias_adjusted = bias - scaled_running_mean
+    bias_adjusted = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+    )
+
+    # Reshape scale and bias_adjusted to match input shape for broadcasting
+    expanded_shape = [1] * len(output_shape)
+    expanded_shape[1] = output_shape[1]  # Set channel dimension
+
+    scale_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_scale",
+        scale,
+        tuple(expanded_shape),
+    )
+    bias_adjusted_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_bias",
+        bias_adjusted,
+        tuple(expanded_shape),
+    )
+
+    # Apply the scale and bias to the input
+    scaled_input = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+    )
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_output",
+        scaled_input,
+        bias_adjusted_reshape,
+    )
 
-    # For BatchNorm1d, reshape output back to 1d
+    # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
            ctx,
            target,
            source_ir,
            f"{name}_reshape_1d",
-            layer.get_output(0),
+            output,
            output_shape,
        )
 
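The rewritten converter builds inference-mode batch norm from elementwise TensorRT layers instead of a single ctx.net.add_scale call, which is what lets weight, bias, running_mean, and running_var arrive as ITensors rather than host-side constants. For reference, a minimal eager-mode sketch of the same rearrangement (illustrative only, not part of this diff):

import torch

x = torch.randn(2, 3, 8, 8)
weight, bias = torch.rand(3), torch.rand(3)
running_mean, running_var = torch.rand(3), torch.rand(3) + 0.5
eps = 1e-05

# scale = weight / sqrt(running_var + eps); bias_adjusted = bias - running_mean * scale
scale = weight / torch.sqrt(running_var + eps)
bias_adjusted = bias - running_mean * scale

# Broadcast the per-channel terms over (N, C, H, W), mirroring the (1, C, 1, 1) reshape above
y_decomposed = x * scale.view(1, -1, 1, 1) + bias_adjusted.view(1, -1, 1, 1)

# Matches the reference op in inference mode (training=False)
y_reference = torch.ops.aten.batch_norm.default(
    x, weight, bias, running_mean, running_var, False, 0.1, eps, True
)
assert torch.allclose(y_decomposed, y_reference, atol=1e-6)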

tests/py/dynamo/conversion/test_batch_norm_aten.py

Lines changed: 99 additions & 6 deletions
@@ -8,12 +8,82 @@
 
 
 class TestBatchNormConverter(DispatchTestCase):
-    def test_batchnorm(self):
+    def test_batchnorm_static_weights(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.batch_norm.default(
                     x,
+                    torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_ITensor_weights_bias(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, weight, bias):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    weight,
+                    bias,
+                    torch.zeros((FEATURE_NUM,)),
                     torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.ones((FEATURE_NUM,)),
+            torch.zeros((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_ITensor_weights(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    weight,
+                    None,
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.ones((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_static_bias_only(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    None,
                     torch.zeros((FEATURE_NUM,)),
                     torch.zeros((FEATURE_NUM,)),
                     torch.ones((FEATURE_NUM,)),
@@ -57,7 +127,7 @@ def forward(self, x):
             input_specs,
         )
 
-    def test_batchnorm_with_dynamic_shape(self):
+    def test_batchnorm2d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.batch_norm.default(
@@ -87,7 +157,7 @@ def forward(self, x):
 
 
 class TestNativeBatchNormConverter(DispatchTestCase):
-    def test_batchnorm(self):
+    def test_native_batchnorm_static_weights(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_batch_norm.default(
@@ -107,7 +177,30 @@ def forward(self, x):
             inputs,
         )
 
-    def test_batchnorm_legit_no_training(self):
+    def test_native_batchnorm_legit_no_training_with_trt_tensor(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, running_mean, running_var):
+                return torch.ops.aten._native_batch_norm_legit_no_training.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    running_mean,
+                    running_var,
+                    0.1,
+                    1e-05,
+                )[0]
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.zeros((FEATURE_NUM,)),
+            torch.ones((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_native_batchnorm_legit_no_training_with_static_means(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten._native_batch_norm_legit_no_training.default(
@@ -126,7 +219,7 @@ def forward(self, x):
             inputs,
         )
 
-    def test_batchnorm1d_with_dynamic_shape(self):
+    def test_native_batchnorm1d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_batch_norm.default(
@@ -153,7 +246,7 @@ def forward(self, x):
             input_specs,
         )
 
-    def test_batchnorm_with_dynamic_shape(self):
+    def test_native_batchnorm2d_with_dynamic_shape(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
                 return torch.ops.aten.native_batch_norm.default(
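The _with_trt_tensor test above feeds running_mean and running_var as forward() arguments, so the converter receives them as ITensors instead of frozen constants. A rough end-to-end sketch of the same situation outside the test harness; the torch_tensorrt.compile call, its options, and the CUDA device are assumptions, not part of this commit:

import torch
import torch_tensorrt

class BatchNormWithStats(torch.nn.Module):
    def forward(self, x, running_mean, running_var):
        # running stats are graph inputs here, so they reach the converter as ITensors
        return torch.ops.aten._native_batch_norm_legit_no_training.default(
            x,
            torch.ones(3, device=x.device),   # weight
            torch.zeros(3, device=x.device),  # bias
            running_mean,
            running_var,
            0.1,    # momentum
            1e-05,  # eps
        )[0]

inputs = [
    torch.randn(1, 3, 224, 224, device="cuda"),
    torch.zeros(3, device="cuda"),
    torch.ones(3, device="cuda"),
]
# Assumed invocation of the dynamo frontend; adjust to your torch_tensorrt version
trt_module = torch_tensorrt.compile(BatchNormWithStats().eval(), ir="dynamo", inputs=inputs)
print(trt_module(*inputs).shape)  # torch.Size([1, 3, 224, 224])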
