Commit 5d6cf81

chore: Ensure input arguments are based on ITensor (TRTTensor)

1 parent: 3c06a9f

File tree: 2 files changed, +113 -101 lines

py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py

Lines changed: 64 additions & 101 deletions

@@ -21,9 +21,8 @@
 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -34,8 +33,8 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     training: bool,
@@ -51,112 +50,76 @@ def batch_norm(
     # Save the original output shape for later use
     output_shape = input.shape

-    # Handle case when running_mean or running_var is TRTTensor
-    if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
-        # Default values if weight, bias, running_mean, running_var are None
-        if weight is None:
-            weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
-        if bias is None:
-            bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
-        if running_mean is None:
-            running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
-        if running_var is None:
-            running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
-
-        # eps_tensor for numerical stability
-        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
-
-        # adjusted_var = running_var + eps
-        adjusted_var = impl.elementwise.add(
-            ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
-        )
+    if weight is None:
+        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
+    if bias is None:
+        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
+    if running_mean is None:
+        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+    if running_var is None:
+        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")

-        # sqrt_adjusted_var = sqrt(adjusted_var)
-        sqrt_adjusted_var = impl.unary.sqrt(
-            ctx, target, source_ir, f"{name}_sqrt", adjusted_var
-        )
+    # eps_tensor for numerical stability
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

-        # scale = weight / sqrt_adjusted_var
-        scale = impl.elementwise.div(
-            ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
-        )
+    # adjusted_var = running_var + eps
+    adjusted_var = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+    )

-        # scaled_running_mean = running_mean * scale
-        scaled_running_mean = impl.elementwise.mul(
-            ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
-        )
+    # sqrt_adjusted_var = sqrt(adjusted_var)
+    sqrt_adjusted_var = impl.unary.sqrt(
+        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+    )

-        # bias_adjusted = bias - scaled_running_mean
-        bias_adjusted = impl.elementwise.sub(
-            ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
-        )
+    # scale = weight / sqrt_adjusted_var
+    scale = impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+    )

-        # Reshape scale and bias_adjusted to match input shape for broadcasting
-        expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
+    # scaled_running_mean = running_mean * scale
+    scaled_running_mean = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+    )

-        scale_reshape = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_scale",
-            scale,
-            tuple(expanded_shape),
-        )
-        bias_adjusted_reshape = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_bias",
-            bias_adjusted,
-            tuple(expanded_shape),
-        )
+    # bias_adjusted = bias - scaled_running_mean
+    bias_adjusted = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+    )

-        # Apply the scale and bias to the input
-        scaled_input = impl.elementwise.mul(
-            ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
-        )
-        output = impl.elementwise.add(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_output",
-            scaled_input,
-            bias_adjusted_reshape,
-        )
+    # Reshape scale and bias_adjusted to match input shape for broadcasting
+    expanded_shape = [1] * len(output_shape)
+    expanded_shape[1] = output_shape[1]  # Set channel dimension

-    else:
-        # Handle the case when running_mean and running_var are not TRTTensor
-        if weight is None:
-            weight = 1.0
-        if bias is None:
-            bias = 0.0
-        if running_mean is None:
-            running_mean = 0.0
-        if running_var is None:
-            running_var = 1.0
-
-        scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-        bias = to_numpy(bias) - to_numpy(running_mean) * scale
-        power = np.ones_like(scale)
-
-        # For BatchNorm1d, reshape 1d to 2d
-        if len(output_shape) < 4:
-            assert (
-                len(get_dynamic_dims(output_shape)) <= 1
-            ), "BatchNorm1D with more than one dynamic dim is not currently supported."
-            new_shape = (
-                (output_shape[0], output_shape[1], 1, 1)
-                if len(output_shape) == 2
-                else (output_shape[0], output_shape[1], output_shape[2], 1)
-            )
-            input = impl.shuffle.reshape(
-                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-            )
+    scale_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_scale",
+        scale,
+        tuple(expanded_shape),
+    )
+    bias_adjusted_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_bias",
+        bias_adjusted,
+        tuple(expanded_shape),
+    )

-        layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-        set_layer_name(layer, target, name, source_ir)
-        output = layer.get_output(0)
+    # Apply the scale and bias to the input
+    scaled_input = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+    )
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_output",
+        scaled_input,
+        bias_adjusted_reshape,
+    )

     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
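
For reference, the converter above folds inference-mode batch norm into a single scale-and-shift: scale = weight / sqrt(running_var + eps), shift = bias - running_mean * scale, output = input * scale + shift, with scale and shift reshaped so they broadcast over the channel dimension. The following standalone NumPy sketch is not part of the commit (the function and variable names are illustrative); it only mirrors the arithmetic that the converter builds out of TensorRT elementwise and shuffle layers.

import numpy as np

def folded_batch_norm(x, weight, bias, running_mean, running_var, eps=1e-5):
    # Same folding as the converter above:
    #   scale = weight / sqrt(running_var + eps)
    #   shift = bias - running_mean * scale
    #   y     = x * scale + shift
    scale = weight / np.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    # Reshape to (1, C, 1, ...) so scale/shift broadcast over the channel
    # dimension, mirroring the expanded_shape reshape in the converter.
    bshape = [1] * x.ndim
    bshape[1] = x.shape[1]
    return x * scale.reshape(bshape) + shift.reshape(bshape)

# Tiny sanity check with identity statistics: output should match the input.
x = np.random.randn(1, 3, 4, 4).astype(np.float32)
y = folded_batch_norm(
    x,
    weight=np.ones(3, dtype=np.float32),
    bias=np.zeros(3, dtype=np.float32),
    running_mean=np.zeros(3, dtype=np.float32),
    running_var=np.ones(3, dtype=np.float32),
)
assert np.allclose(x, y, atol=1e-3)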

tests/py/dynamo/conversion/test_batch_norm_aten.py

Lines changed: 49 additions & 0 deletions

@@ -29,6 +29,55 @@ def forward(self, x):
             inputs,
         )

+    def test_batchnorm_ITensor_weights_bias(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, weight, bias):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    weight,
+                    bias,
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.ones((FEATURE_NUM,)),
+            torch.zeros((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
+    def test_batchnorm_ITensor_weights(self):
+        class BatchNorm(torch.nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten.batch_norm.default(
+                    x,
+                    weight,
+                    None,
+                    torch.zeros((FEATURE_NUM,)),
+                    torch.ones((FEATURE_NUM,)),
+                    False,
+                    0.1,
+                    1e-05,
+                    True,
+                )
+
+        inputs = [
+            torch.randn(1, 3, 224, 224),
+            torch.ones((FEATURE_NUM,)),
+        ]
+        self.run_test(
+            BatchNorm(),
+            inputs,
+        )
+
     def test_batchnorm_static_bias_only(self):
         class BatchNorm(torch.nn.Module):
             def forward(self, x):
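
As context for the new tests: when weight and bias arrive as graph inputs rather than frozen constants, the dynamo converter now receives them as ITensors. Below is a rough end-to-end sketch of that scenario; it is not part of this commit, assumes a CUDA device and a recent Torch-TensorRT release, and the exact compile options may differ by version (FEATURE_NUM here is an illustrative stand-in for the constant defined in the test file).

import torch
import torch_tensorrt

FEATURE_NUM = 3  # assumed value for illustration

class BatchNorm(torch.nn.Module):
    def forward(self, x, weight, bias):
        # weight/bias are runtime inputs, so the converter sees them as ITensors
        mean = torch.zeros((FEATURE_NUM,), device=x.device)
        var = torch.ones((FEATURE_NUM,), device=x.device)
        return torch.ops.aten.batch_norm.default(
            x, weight, bias, mean, var, False, 0.1, 1e-05, True
        )

inputs = [
    torch.randn(1, FEATURE_NUM, 224, 224).cuda(),
    torch.ones((FEATURE_NUM,)).cuda(),
    torch.zeros((FEATURE_NUM,)).cuda(),
]
trt_mod = torch_tensorrt.compile(BatchNorm().eval().cuda(), ir="dynamo", inputs=inputs)
print(trt_mod(*inputs).shape)  # torch.Size([1, 3, 224, 224])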
