@@ -36,59 +36,136 @@ def batch_norm(
     input: TRTTensor,
     weight: Optional[Union[torch.Tensor, np.ndarray]],
     bias: Optional[Union[torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[torch.Tensor, np.ndarray]],
+    running_mean: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
+    running_var: Union[TRTTensor, Optional[Union[torch.Tensor, np.ndarray]]],
     training: bool,
     momentum: float,
     eps: float,
     cudnn_enabled: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

-    if weight is None:
-        weight = 1.0
+    # Save the original output shape for later use
+    output_shape = input.shape

-    if bias is None:
-        bias = 0.0
+    # Handle case when running_mean or running_var is TRTTensor
+    if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
+        # Default values if weight, bias, running_mean, running_var are None
+        if weight is None:
+            weight = get_trt_tensor(ctx, 1.0, f"{name}_weight", input.dtype)
+        if bias is None:
+            bias = get_trt_tensor(ctx, 0.0, f"{name}_bias", input.dtype)
+        if running_mean is None:
+            running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean", input.dtype)
+        if running_var is None:
+            running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var", input.dtype)
+
+        # eps_tensor for numerical stability
+        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
+
+        # adjusted_var = running_var + eps
+        adjusted_var = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+        )

-    if running_mean is None:
-        running_mean = 0.0
+        # sqrt_adjusted_var = sqrt(adjusted_var)
+        sqrt_adjusted_var = impl.unary.sqrt(
+            ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+        )

-    if running_var is None:
-        running_var = 1.0
+        # scale = weight / sqrt_adjusted_var
+        scale = impl.elementwise.div(
+            ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+        )

-    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-    bias = to_numpy(bias) - to_numpy(running_mean) * scale
-    power = np.ones_like(scale)
+        # scaled_running_mean = running_mean * scale
+        scaled_running_mean = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+        )

-    # For BatchNorm1d, reshape 1d to 2d
-    output_shape = input.shape
-    if len(input.shape) < 4:
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        new_shape = (
-            (input.shape[0], input.shape[1], 1, 1)
-            if len(input.shape) == 2
-            else (input.shape[0], input.shape[1], input.shape[2], 1)
+        # bias_adjusted = bias - scaled_running_mean
+        bias_adjusted = impl.elementwise.sub(
+            ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
         )
-        input = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+
+        # Reshape scale and bias_adjusted to match input shape for broadcasting
+        expanded_shape = [1] * len(output_shape)
+        expanded_shape[1] = output_shape[1]  # Set channel dimension
+
+        scale_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_scale",
+            scale,
+            tuple(expanded_shape),
         )
-    layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name, source_ir)
-    output = layer.get_output(0)
+        bias_adjusted_reshape = impl.shuffle.reshape(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_reshape_bias",
+            bias_adjusted,
+            tuple(expanded_shape),
+        )
+
+        # Apply the scale and bias to the input
+        scaled_input = impl.elementwise.mul(
+            ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+        )
+        output = impl.elementwise.add(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_output",
+            scaled_input,
+            bias_adjusted_reshape,
+        )
+
+    else:
+        # Handle the case when running_mean and running_var are not TRTTensor
+        if weight is None:
+            weight = 1.0
+        if bias is None:
+            bias = 0.0
+        if running_mean is None:
+            running_mean = 0.0
+        if running_var is None:
+            running_var = 1.0
+
+        scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
+        bias = to_numpy(bias) - to_numpy(running_mean) * scale
+        power = np.ones_like(scale)
+
+        # For BatchNorm1d, reshape 1d to 2d
+        if len(output_shape) < 4:
+            assert (
+                len(get_dynamic_dims(output_shape)) <= 1
+            ), "BatchNorm1D with more than one dynamic dim is not currently supported."
+            new_shape = (
+                (output_shape[0], output_shape[1], 1, 1)
+                if len(output_shape) == 2
+                else (output_shape[0], output_shape[1], output_shape[2], 1)
+            )
+            input = impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
+            )
+
+        layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
+        set_layer_name(layer, target, name, source_ir)
+        output = layer.get_output(0)

-    # For BatchNorm1d, reshape output back to 1d
+    # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
             ctx,
             target,
             source_ir,
             f"{name}_reshape_1d",
-            layer.get_output(0),
+            output,
             output_shape,
         )
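For reference, the new TRTTensor branch folds inference-time batch norm into a per-channel affine transform: scale = weight / sqrt(running_var + eps) and a shifted bias of bias - running_mean * scale, which are then broadcast-multiplied and added to the input. Below is a minimal NumPy sketch (not part of this PR; shapes, values, and tolerance are illustrative assumptions) checking that this folding matches the textbook formula (x - mean) / sqrt(var + eps) * weight + bias:

import numpy as np

# Sketch only: verify the scale/bias folding used in the TRTTensor branch.
# Shapes (N=2, C=3, H=W=4) and the tolerance are illustrative assumptions.
x = np.random.randn(2, 3, 4, 4).astype(np.float32)
w, b = np.random.randn(3), np.random.randn(3)
mean, var, eps = np.random.randn(3), np.random.rand(3) + 0.1, 1e-5

scale = w / np.sqrt(var + eps)   # analogous to the "_scale" tensor above
shift = b - mean * scale         # analogous to "_bias_adjusted"
folded = x * scale.reshape(1, 3, 1, 1) + shift.reshape(1, 3, 1, 1)

# Reference inference-time batch norm
reference = (x - mean.reshape(1, 3, 1, 1)) / np.sqrt(
    var.reshape(1, 3, 1, 1) + eps
) * w.reshape(1, 3, 1, 1) + b.reshape(1, 3, 1, 1)

assert np.allclose(folded, reference, atol=1e-5)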