@@ -29,13 +29,13 @@ def batch_norm(
29
29
source_ir : Optional [SourceIR ],
30
30
name : str ,
31
31
input : TRTTensor ,
32
- weight : torch .Tensor ,
33
- bias : torch .Tensor ,
34
- running_mean : torch .Tensor ,
35
- running_var : torch .Tensor ,
36
- training : torch . Tensor ,
37
- momentum : torch . Tensor ,
38
- eps : List [ float ] ,
32
+ weight : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
33
+ bias : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
34
+ running_mean : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
35
+ running_var : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
36
+ training : bool ,
37
+ momentum : float ,
38
+ eps : float ,
39
39
cudnn_enabled : bool ,
40
40
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
41
41
if not isinstance (input , TRTTensor ):
@@ -47,8 +47,20 @@ def batch_norm(
47
47
if has_dynamic_shape (input .shape ):
48
48
assert input .shape [1 ] != - 1 , "Channel dim can't be dynamic for batch norm."
49
49
50
+ if weight is None :
51
+ weight = np .array (1.0 )
52
+
53
+ if bias is None :
54
+ bias = np .array (0.0 )
55
+
56
+ if running_mean is None :
57
+ running_mean = np .array (0.0 )
58
+
59
+ if running_var is None :
60
+ running_var = np .array (1.0 )
61
+
50
62
scale = cast (torch .Tensor , to_numpy (weight )) / np .sqrt (
51
- cast (torch .Tensor , to_numpy (running_var )) + cast ( float , eps )
63
+ cast (torch .Tensor , to_numpy (running_var )) + eps
52
64
)
53
65
54
66
bias = to_numpy (bias ) - to_numpy (running_mean ) * scale
@@ -91,9 +103,9 @@ def layer_norm(
91
103
name : str ,
92
104
input : TRTTensor ,
93
105
normalized_shape : List [int ],
94
- weight : torch .Tensor ,
95
- bias : torch .Tensor ,
96
- eps : List [ float ] ,
106
+ weight : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
107
+ bias : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
108
+ eps : float ,
97
109
cudnn_enable : bool ,
98
110
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
99
111
if not isinstance (input , trt .tensorrt .ITensor ):
@@ -102,6 +114,12 @@ def layer_norm(
102
114
"of the TensorRT region!"
103
115
)
104
116
117
+ if weight is None :
118
+ weight = np .array (1.0 )
119
+
120
+ if bias is None :
121
+ bias = np .array (0.0 )
122
+
105
123
gamma = (
106
124
weight .detach ().cpu ().float ().numpy ()
107
125
if isinstance (weight , torch .Tensor )
@@ -152,16 +170,22 @@ def layer_norm_no_plugin(
152
170
name : str ,
153
171
input : TRTTensor ,
154
172
normalized_shape : List [int ],
155
- weight : torch .Tensor ,
156
- bias : torch .Tensor ,
157
- eps : List [ float ] ,
173
+ weight : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
174
+ bias : Optional [ Union [ TRTTensor , torch .Tensor , np . ndarray ]] ,
175
+ eps : float ,
158
176
) -> Union [TRTTensor , Sequence [TRTTensor ]]:
159
177
if not isinstance (input , TRTTensor ):
160
178
raise RuntimeError (
161
179
f"LayerNorm received input { input } that is not part "
162
180
"of the TensorRT region!"
163
181
)
164
182
183
+ if weight is None :
184
+ weight = np .array (1.0 )
185
+
186
+ if bias is None :
187
+ bias = np .array (0.0 )
188
+
165
189
shape = weight .shape
166
190
broadcasted_shape = (1 ,) * (len (input .shape ) - len (shape )) + shape
167
191
gamma = to_numpy (weight .reshape (* shape ))
0 commit comments