 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -34,8 +33,8 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     training: bool,
@@ -51,112 +50,76 @@ def batch_norm(
     # Save the original output shape for later use
     output_shape = input.shape
 
-    # Handle case when running_mean or running_var is TRTTensor
-    if isinstance(running_mean, TRTTensor) or isinstance(running_var, TRTTensor):
-        # Default values if weight, bias, running_mean, running_var are None
-        if weight is None:
-            weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
-        if bias is None:
-            bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
-        if running_mean is None:
-            running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
-        if running_var is None:
-            running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
-
-        # eps_tensor for numerical stability
-        eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
-
-        # adjusted_var = running_var + eps
-        adjusted_var = impl.elementwise.add(
-            ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
-        )
+    if weight is None:
+        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
+    if bias is None:
+        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
+    if running_mean is None:
+        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
+    if running_var is None:
+        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")
 
-        # sqrt_adjusted_var = sqrt(adjusted_var)
-        sqrt_adjusted_var = impl.unary.sqrt(
-            ctx, target, source_ir, f"{name}_sqrt", adjusted_var
-        )
+    # eps_tensor for numerical stability
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")
 
-        # scale = weight / sqrt_adjusted_var
-        scale = impl.elementwise.div(
-            ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
-        )
+    # adjusted_var = running_var + eps
+    adjusted_var = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+    )
 
-        # scaled_running_mean = running_mean * scale
-        scaled_running_mean = impl.elementwise.mul(
-            ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
-        )
+    # sqrt_adjusted_var = sqrt(adjusted_var)
+    sqrt_adjusted_var = impl.unary.sqrt(
+        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+    )
 
-        # bias_adjusted = bias - scaled_running_mean
-        bias_adjusted = impl.elementwise.sub(
-            ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
-        )
+    # scale = weight / sqrt_adjusted_var
+    scale = impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+    )
 
-        # Reshape scale and bias_adjusted to match input shape for broadcasting
-        expanded_shape = [1] * len(output_shape)
-        expanded_shape[1] = output_shape[1]  # Set channel dimension
+    # scaled_running_mean = running_mean * scale
+    scaled_running_mean = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+    )
 
-        scale_reshape = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_scale",
-            scale,
-            tuple(expanded_shape),
-        )
-        bias_adjusted_reshape = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_bias",
-            bias_adjusted,
-            tuple(expanded_shape),
-        )
+    # bias_adjusted = bias - scaled_running_mean
+    bias_adjusted = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+    )
 
-        # Apply the scale and bias to the input
-        scaled_input = impl.elementwise.mul(
-            ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
-        )
-        output = impl.elementwise.add(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_output",
-            scaled_input,
-            bias_adjusted_reshape,
-        )
+    # Reshape scale and bias_adjusted to match input shape for broadcasting
+    expanded_shape = [1] * len(output_shape)
+    expanded_shape[1] = output_shape[1]  # Set channel dimension
 
-    else:
-        # Handle the case when running_mean and running_var are not TRTTensor
-        if weight is None:
-            weight = 1.0
-        if bias is None:
-            bias = 0.0
-        if running_mean is None:
-            running_mean = 0.0
-        if running_var is None:
-            running_var = 1.0
-
-        scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-        bias = to_numpy(bias) - to_numpy(running_mean) * scale
-        power = np.ones_like(scale)
-
-        # For BatchNorm1d, reshape 1d to 2d
-        if len(output_shape) < 4:
-            assert (
-                len(get_dynamic_dims(output_shape)) <= 1
-            ), "BatchNorm1D with more than one dynamic dim is not currently supported."
-            new_shape = (
-                (output_shape[0], output_shape[1], 1, 1)
-                if len(output_shape) == 2
-                else (output_shape[0], output_shape[1], output_shape[2], 1)
-            )
-            input = impl.shuffle.reshape(
-                ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-            )
+    scale_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_scale",
+        scale,
+        tuple(expanded_shape),
+    )
+    bias_adjusted_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_bias",
+        bias_adjusted,
+        tuple(expanded_shape),
+    )
 
-        layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-        set_layer_name(layer, target, name, source_ir)
-        output = layer.get_output(0)
+    # Apply the scale and bias to the input
+    scaled_input = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+    )
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_output",
+        scaled_input,
+        bias_adjusted_reshape,
+    )
 
     # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
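The unified path above builds the standard inference-time batch-norm folding directly in the TensorRT graph: y = (x - running_mean) / sqrt(running_var + eps) * weight + bias is regrouped as y = x * scale + bias_adjusted, where scale = weight / sqrt(running_var + eps) and bias_adjusted = bias - running_mean * scale, each broadcast over the input via a [1, C, 1, ...] reshape. A minimal standalone sketch in plain PyTorch (not part of this diff; shapes are made up for illustration) checks that the folded form matches torch.nn.functional.batch_norm in eval mode:

```python
import torch

# Made-up shapes for illustration: N=2, C=3, H=W=4.
x = torch.randn(2, 3, 4, 4)
weight = torch.randn(3)
bias = torch.randn(3)
running_mean = torch.randn(3)
running_var = torch.rand(3) + 0.5  # keep variances positive
eps = 1e-5

# Folded form mirroring the converter above.
scale = weight / torch.sqrt(running_var + eps)
bias_adjusted = bias - running_mean * scale

# Reshape to [1, C, 1, 1] so per-channel values broadcast over N, H, W.
expanded_shape = [1] * x.dim()
expanded_shape[1] = x.shape[1]
y_folded = x * scale.reshape(expanded_shape) + bias_adjusted.reshape(expanded_shape)

# Reference: PyTorch's inference-mode batch norm.
y_ref = torch.nn.functional.batch_norm(
    x, running_mean, running_var, weight, bias, training=False, eps=eps
)
assert torch.allclose(y_folded, y_ref, atol=1e-6)
```

Because scale and bias_adjusted are now built with impl.elementwise and impl.unary network ops rather than precomputed NumPy constants fed to add_scale, the same code path serves both the constant case and the case where running_mean or running_var arrive as live TRTTensors, which is why the old two-branch structure could be removed.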