@@ -137,8 +137,8 @@ def elu(
 
     def elu_dyn_range_fn(dyn_range):
         return (
-            torch.nn.ELU(dyn_range[0], alpha),
-            torch.nn.ELU(dyn_range[1], alpha),
+            torch.nn.functional.elu(dyn_range[0], alpha),
+            torch.nn.functional.elu(dyn_range[1], alpha),
         )
 
     return convert_activation(
@@ -163,7 +163,10 @@ def selu(
     operation_type = trt.ActivationType.SELU
 
     def selu_dyn_range_fn(dyn_range):
-        return (torch.nn.SELU(dyn_range[0]), torch.nn.SELU(dyn_range[1]))
+        return (
+            torch.nn.functional.selu(dyn_range[0]),
+            torch.nn.functional.selu(dyn_range[1]),
+        )
 
     return convert_activation(
         network,
@@ -187,7 +190,10 @@ def softsign(
     operation_type = trt.ActivationType.SOFTSIGN
 
     def softsign_dyn_range_fn(dyn_range):
-        return (torch.nn.Softsign(dyn_range[0]), torch.nn.Softsign(dyn_range[1]))
+        return (
+            torch.nn.functional.softsign(dyn_range[0]),
+            torch.nn.functional.softsign(dyn_range[1]),
+        )
 
     return convert_activation(
         network,
@@ -212,8 +218,8 @@ def softplus(
 
     def softplus_dyn_range_fn(dyn_range):
         return (
-            torch.nn.Softplus(dyn_range[0], beta),
-            torch.nn.Softplus(dyn_range[1], beta),
+            torch.nn.functional.softplus(dyn_range[0], beta),
+            torch.nn.functional.softplus(dyn_range[1], beta),
         )
 
     return convert_activation(
@@ -303,7 +309,7 @@ def scaled_tanh(
 
     def scaled_tanh_dyn_range_fn(dyn_range):
         def scaled_tanh_fn(x):
-            return alpha * torch.nn.Tanh(beta * x)
+            return alpha * torch.nn.functional.tanh(beta * x)
 
         return scaled_tanh_fn(dyn_range[0]), scaled_tanh_fn(dyn_range[1])
 
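The pattern being replaced is consistent across the hunks: torch.nn.ELU, torch.nn.SELU, torch.nn.Softsign, torch.nn.Softplus, and torch.nn.Tanh are module constructors, so calling them on a dynamic-range endpoint builds an nn.Module (with the endpoint interpreted as a constructor argument, or rejected outright) instead of evaluating the activation, whereas the torch.nn.functional counterparts apply the activation to the value directly. A minimal sketch of the difference, using an illustrative tensor x and alpha=0.5 that are not part of the diff:

import torch

x = torch.tensor([-1.0, 0.0, 1.0])

# Module form: positional/keyword args configure the module (alpha, inplace);
# the tensor is only transformed by a second, separate call on the module.
elu_module = torch.nn.ELU(alpha=0.5)
y_module = elu_module(x)

# Functional form: evaluates ELU on the tensor in a single call, which is
# what a dyn_range function needs when mapping range endpoints.
y_functional = torch.nn.functional.elu(x, alpha=0.5)

assert torch.allclose(y_module, y_functional)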