 from torch_tensorrt.dynamo.conversion.impl.cat import cat
 from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
 from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.types import TRTTensor
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.fx.types import TRTTensor
-from torch_tensorrt.fx.utils import get_dynamic_dims

 _LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -34,61 +33,102 @@ def batch_norm(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    weight: Optional[Union[torch.Tensor, np.ndarray]],
-    bias: Optional[Union[torch.Tensor, np.ndarray]],
-    running_mean: Optional[Union[torch.Tensor, np.ndarray]],
-    running_var: Optional[Union[torch.Tensor, np.ndarray]],
+    weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    running_mean: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
+    running_var: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
     training: bool,
     momentum: float,
     eps: float,
     cudnn_enabled: bool,
     return_mean_rstd: bool,
 ) -> Union[TRTTensor, Tuple[TRTTensor, torch.Tensor, torch.Tensor]]:
+
     if has_dynamic_shape(input.shape):
         assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm."

-    if weight is None:
-        weight = 1.0
+    # Save the original output shape for later use
+    output_shape = input.shape

+    if weight is None:
+        weight = get_trt_tensor(ctx, 1.0, f"{name}_weight")
     if bias is None:
-        bias = 0.0
-
+        bias = get_trt_tensor(ctx, 0.0, f"{name}_bias")
     if running_mean is None:
-        running_mean = 0.0
-
+        running_mean = get_trt_tensor(ctx, 0.0, f"{name}_running_mean")
     if running_var is None:
-        running_var = 1.0
+        running_var = get_trt_tensor(ctx, 1.0, f"{name}_running_var")

-    scale = to_numpy(weight) / np.sqrt(to_numpy(running_var) + eps)
-    bias = to_numpy(bias) - to_numpy(running_mean) * scale
-    power = np.ones_like(scale)
+    # eps_tensor for numerical stability
+    eps_tensor = get_trt_tensor(ctx, eps, f"{name}_eps")

-    # For BatchNorm1d, reshape 1d to 2d
-    output_shape = input.shape
-    if len(input.shape) < 4:
-        assert (
-            len(get_dynamic_dims(input.shape)) <= 1
-        ), "BatchNorm1D with more than one dynamic dims is not currently supported."
-        new_shape = (
-            (input.shape[0], input.shape[1], 1, 1)
-            if len(input.shape) == 2
-            else (input.shape[0], input.shape[1], input.shape[2], 1)
-        )
-        input = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_2d", input, new_shape
-        )
-    layer = ctx.net.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power)
-    set_layer_name(layer, target, name, source_ir)
-    output = layer.get_output(0)
+    # adjusted_var = running_var + eps
+    adjusted_var = impl.elementwise.add(
+        ctx, target, source_ir, f"{name}_adjusted_var", running_var, eps_tensor
+    )
+
+    # sqrt_adjusted_var = sqrt(adjusted_var)
+    sqrt_adjusted_var = impl.unary.sqrt(
+        ctx, target, source_ir, f"{name}_sqrt", adjusted_var
+    )
+
+    # scale = weight / sqrt_adjusted_var
+    scale = impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_scale", weight, sqrt_adjusted_var
+    )
+
+    # scaled_running_mean = running_mean * scale
+    scaled_running_mean = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_running_mean", running_mean, scale
+    )
+
+    # bias_adjusted = bias - scaled_running_mean
+    bias_adjusted = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_bias_adjusted", bias, scaled_running_mean
+    )
+
+    # Reshape scale and bias_adjusted to match input shape for broadcasting
+    expanded_shape = [1] * len(output_shape)
+    expanded_shape[1] = output_shape[1]  # Set channel dimension
+
+    scale_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_scale",
+        scale,
+        tuple(expanded_shape),
+    )
+    bias_adjusted_reshape = impl.shuffle.reshape(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_reshape_bias",
+        bias_adjusted,
+        tuple(expanded_shape),
+    )
+
+    # Apply the scale and bias to the input
+    scaled_input = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_scaled_input", input, scale_reshape
+    )
+    output = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_output",
+        scaled_input,
+        bias_adjusted_reshape,
+    )

-    # For BatchNorm1d, reshape output back to 1d
+    # For BatchNorm1d, reshape output back to original shape if necessary
     if len(output_shape) < 4:
         output = impl.shuffle.reshape(
             ctx,
             target,
             source_ir,
             f"{name}_reshape_1d",
-            layer.get_output(0),
+            output,
             output_shape,
         )
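For context on the rewritten converter: the new add/sqrt/div/mul/sub layers fold inference-mode batch norm, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias, into a single per-channel scale and an adjusted bias, broadcast over the channel dimension via the (1, C, 1, ...) reshapes. A minimal NumPy sketch, independent of TensorRT and using made-up shapes and values, checking that the folded form matches the reference formula:

import numpy as np

# Hypothetical NCHW input: batch of 2, 3 channels, 4x4 spatial.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4)).astype(np.float32)
weight = rng.standard_normal(3).astype(np.float32)
bias = rng.standard_normal(3).astype(np.float32)
running_mean = rng.standard_normal(3).astype(np.float32)
running_var = rng.random(3).astype(np.float32)  # non-negative variances
eps = 1e-5

# Reference inference-mode batch norm: normalize, then scale and shift.
ref = (x - running_mean[None, :, None, None]) / np.sqrt(
    running_var[None, :, None, None] + eps
) * weight[None, :, None, None] + bias[None, :, None, None]

# Folded form mirroring the converter: per-channel scale and adjusted bias,
# broadcast to (1, C, 1, 1) just like the reshape layers in the diff.
scale = weight / np.sqrt(running_var + eps)
bias_adjusted = bias - running_mean * scale
folded = x * scale[None, :, None, None] + bias_adjusted[None, :, None, None]

assert np.allclose(ref, folded, atol=1e-4)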