
Commit e1773b4

feat: Add handling for ITensor mean and var in batch_norm
1 parent 03092ba commit e1773b4

File tree

2 files changed: +158 -37 lines changed
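
Note: in eval mode, batch norm reduces to a per-channel affine transform. The TRTTensor branch added by this commit builds that transform out of TensorRT elementwise/unary layers instead of a static IScaleLayer, because mean and variance supplied as ITensors are only known at runtime. A minimal NumPy sketch of the same arithmetic (names mirror the diff; batch_norm_reference is a hypothetical helper for illustration, not the converter code):

import numpy as np

def batch_norm_reference(x, weight, bias, running_mean, running_var, eps=1e-5):
    # scale = weight / sqrt(running_var + eps), one value per channel
    scale = weight / np.sqrt(running_var + eps)
    # bias_adjusted = bias - running_mean * scale
    bias_adjusted = bias - running_mean * scale
    # reshape to (1, C, 1, ...) so the per-channel terms broadcast over NCHW input
    shape = [1] * x.ndim
    shape[1] = x.shape[1]
    return x * scale.reshape(shape) + bias_adjusted.reshape(shape)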

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 107 additions & 30 deletions
@@ -36,59 +36,136 @@ def batch_norm(
    input: TRTTensor,
    weight: Optional[Union[torch.Tensor, np.ndarray]],
    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[torch.Tensor, np.ndarray]],
+    running_mean: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
+    running_var: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
    training: bool,
    momentum: float,
    eps: float,
    cudnn_enabled: bool,
    return_mean_rstd: bool,
) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+
    if has_dynamic_shape(input.shape):
        assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

-    if weight is None:
-        weight = 1.0
+    # Save the original output shape for later use
+    output_shape = input.shape

-    if bias is None:
-        bias = 0.0
+    # Handle case when running_mean or running_var is TRTTensor
+    if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
+        # Default values if weight, bias, running_mean, running_var are None
+        if weight is None:
+            weight = get_trt_tensor(ctx, 1.0, f"{name}_weight", input.dtype)
+        if bias is None:
+            bias = get_trt_tensor(ctx, 0.0, f"{name}_bias", input.dtype)
+        if running_mean is None:
+            running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean", input.dtype)
+        if running_var is None:
+            running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var", input.dtype)
+
+        # eps_tensor for numerical stability
+        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
+
+        # adjusted_var = running_var + eps
+        adjusted_var = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+        )

-    if running_mean is None:
-        running_mean = 0.0
+        # sqrt_adjusted_var = sqrt(adjusted_var)
+        sqrt_adjusted_var = impl.unary.sqrt(
+            ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+        )

-    if running_var is None:
-        running_var = 1.0
+        # scale = weight / sqrt_adjusted_var
+        scale = impl.elementwise.div(
+            ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+        )

-    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-    bias = to_numpy(bias) - to_numpy(running_mean) * scale
-    power = np.ones_like(scale)
+        # scaled_running_mean = running_mean * scale
+        scaled_running_mean = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+        )

-    # For BatchNorm1d, reshape 1d to 2d
-    output_shape = input.shape
-    if len(input.shape) < 4:
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        new_shape = (
-            (input.shape[0], input.shape[1], 1, 1)
-            if len(input.shape) == 2
-            else (input.shape[0], input.shape[1], input.shape[2], 1)
+        # bias_adjusted = bias - scaled_running_mean
+        bias_adjusted = impl.elementwise.sub(
+            ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
        )
-        input = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+
+        # Reshape scale and bias_adjusted to match input shape for broadcasting
+        expanded_shape = [1] * len(output_shape)
+        expanded_shape[1] = output_shape[1]  # Set channel dimension
+
+        scale_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_scale",
+            scale,
+            tuple(expanded_shape),
        )
-    layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name, source_ir)
-    output = layer.get_output(0)
+        bias_adjusted_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_bias",
+            bias_adjusted,
+            tuple(expanded_shape),
+        )
+
+        # Apply the scale and bias to the input
+        scaled_input = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+        )
+        output = impl.elementwise.add(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_output",
+            scaled_input,
+            bias_adjusted_reshape,
+        )
+
+    else:
+        # Handle the case when running_mean and running_var are not TRTTensor
+        if weight is None:
+            weight = 1.0
+        if bias is None:
+            bias = 0.0
+        if running_mean is None:
+            running_mean = 0.0
+        if running_var is None:
+            running_var = 1.0
+
+        scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
+        bias = to_numpy(bias) - to_numpy(running_mean) * scale
+        power = np.ones_like(scale)
+
+        # For BatchNorm1d, reshape 1d to 2d
+        if len(output_shape) < 4:
+            assert (
+                len(get_dynamic_dims(output_shape)) <= 1
+            ), "BatchNorm1D with more than one dynamic dim is not currently supported."
+            new_shape = (
+                (output_shape[0], output_shape[1], 1, 1)
+                if len(output_shape) == 2
+                else (output_shape[0], output_shape[1], output_shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
+
+        layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)

-    # For BatchNorm1d, reshape output back to 1d
+    # For BatchNorm1d, reshape output back to original shape if necessary
    if len(output_shape) < 4:
        output = impl.shuffle.reshape(
            ctx,
            target,
            source_ir,
            f"{name}_reshape_1d",
-            layer.get_output(0),
+            output,
            output_shape,
        )
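
The else branch above keeps the original NumPy + add_scale path for static weights. A quick PyTorch check that the elementwise decomposition used in the TRTTensor branch matches eval-mode batch norm (assumed test values, for illustration only):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
weight, bias = torch.rand(3) + 0.5, torch.randn(3)
mean, var = torch.randn(3), torch.rand(3) + 0.5
eps = 1e-5

# Same decomposition as the converter: scale, shifted bias, broadcast over channels
scale = weight / torch.sqrt(var + eps)
shift = bias - mean * scale
decomposed = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)

reference = F.batch_norm(x, mean, var, weight, bias, training=False, eps=eps)
assert torch.allclose(decomposed, reference, atol=1e-6)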

tests/py/dynamo/conversion/test_batch_norm_aten.py

Lines changed: 51 additions & 7 deletions
@@ -8,12 +8,33 @@


class TestBatchNormConverter(DispatchTestCase):
-    def test_batchnorm(self):
+    def test_batchnorm_static_weights(self):
        class BatchNorm(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.batch_norm.default(
                    x,
-                    torch.ones((FEATURE_NUM,)),
+                    torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.full((FEATURE_NUM,), 3, dtype=torch.float32),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [torch.randn(1, 3, 224, 224)]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_static_bias_only(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    None,
                    torch.zeros((FEATURE_NUM,)),
                    torch.zeros((FEATURE_NUM,)),
                    torch.ones((FEATURE_NUM,)),

@@ -57,7 +78,7 @@ def forward(self, x):
            input_specs,
        )

-    def test_batchnorm_with_dynamic_shape(self):
+    def test_batchnorm2d_with_dynamic_shape(self):
        class BatchNorm(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.batch_norm.default(

@@ -87,7 +108,7 @@ def forward(self, x):


class TestNativeBatchNormConverter(DispatchTestCase):
-    def test_batchnorm(self):
+    def test_native_batchnorm_static_weights(self):
        class BatchNorm(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.native_batch_norm.default(

@@ -107,7 +128,30 @@ def forward(self, x):
            inputs,
        )

-    def test_batchnorm_legit_no_training(self):
+    def test_native_batchnorm_legit_no_training_with_trt_tensor(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, running_mean, running_var):
+                return torch.ops.aten._native_batch_norm_legit_no_training.default(
+                    x,
+                    torch.ones((FEATURE_NUM,)),
+                    torch.zeros((FEATURE_NUM,)),
+                    running_mean,
+                    running_var,
+                    0.1,
+                    1e-05,
+                )[0]
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.zeros((FEATURE_NUM,)),
+            torch.ones((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_native_batchnorm_legit_no_training_with_static_means(self):
        class BatchNorm(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten._native_batch_norm_legit_no_training.default(

@@ -126,7 +170,7 @@ def forward(self, x):
            inputs,
        )

-    def test_batchnorm1d_with_dynamic_shape(self):
+    def test_native_batchnorm1d_with_dynamic_shape(self):
        class BatchNorm(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.native_batch_norm.default(

@@ -153,7 +197,7 @@ def forward(self, x):
            input_specs,
        )

-    def test_batchnorm_with_dynamic_shape(self):
+    def test_native_batchnorm2d_with_dynamic_shape(self):
        class BatchNorm(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.native_batch_norm.default(
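
The new test_native_batchnorm_legit_no_training_with_trt_tensor case passes running_mean and running_var as forward() arguments, so they reach the converter as TRTTensors rather than frozen constants and exercise the new branch. To run just that case locally, a pytest filter along these lines should work (assumed invocation):

python -m pytest tests/py/dynamo/conversion/test_batch_norm_aten.py -k trt_tensor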
