
Commit b8d189a

committed
change function calls from nn.Module to nn.functional
1 parent 910cf06 commit b8d189a
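
Note on the fix: torch.nn.ELU, torch.nn.SELU, torch.nn.Softsign, torch.nn.Softplus, and torch.nn.Tanh are nn.Module classes. Calling them with a tensor as the first positional argument constructs a module (the tensor lands in the constructor's alpha/inplace slot) instead of computing the activation, so the dyn-range helpers below returned module objects rather than tensors. The torch.nn.functional counterparts apply the operation directly. A minimal sketch of the difference, using an illustrative input tensor:

    import torch

    x = torch.tensor([-2.0, 0.0, 2.0])  # illustrative input
    alpha = 1.0

    # Before: torch.nn.ELU(x, alpha) builds an ELU module with x as its
    # alpha argument; it does not apply the activation to x.
    elu_module = torch.nn.ELU(alpha)    # module usage takes two steps:
    y_module = elu_module(x)            # construct, then call on the input

    # After: the functional form applies the activation in one call.
    y_functional = torch.nn.functional.elu(x, alpha)

    assert torch.allclose(y_module, y_functional)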

File tree (1 file changed: +13 -7 lines)
  • py/torch_tensorrt/dynamo/conversion/impl/activation


py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py

Lines changed: 13 additions & 7 deletions
@@ -137,8 +137,8 @@ def elu(
 
     def elu_dyn_range_fn(dyn_range):
         return (
-            torch.nn.ELU(dyn_range[0], alpha),
-            torch.nn.ELU(dyn_range[1], alpha),
+            torch.nn.functional.elu(dyn_range[0], alpha),
+            torch.nn.functional.elu(dyn_range[1], alpha),
         )
 
     return convert_activation(
@@ -163,7 +163,10 @@ def selu(
     operation_type = trt.ActivationType.SELU
 
     def selu_dyn_range_fn(dyn_range):
-        return (torch.nn.SELU(dyn_range[0]), torch.nn.SELU(dyn_range[1]))
+        return (
+            torch.nn.functional.selu(dyn_range[0]),
+            torch.nn.functional.selu(dyn_range[1]),
+        )
 
     return convert_activation(
         network,
@@ -187,7 +190,10 @@ def softsign(
     operation_type = trt.ActivationType.SOFTSIGN
 
     def softsign_dyn_range_fn(dyn_range):
-        return (torch.nn.Softsign(dyn_range[0]), torch.nn.Softsign(dyn_range[1]))
+        return (
+            torch.nn.functional.softsign(dyn_range[0]),
+            torch.nn.functional.softsign(dyn_range[1]),
+        )
 
     return convert_activation(
         network,
@@ -212,8 +218,8 @@ def softplus(
 
     def softplus_dyn_range_fn(dyn_range):
         return (
-            torch.nn.Softplus(dyn_range[0], beta),
-            torch.nn.Softplus(dyn_range[1], beta),
+            torch.nn.functional.softplus(dyn_range[0], beta),
+            torch.nn.functional.softplus(dyn_range[1], beta),
         )
 
     return convert_activation(
@@ -303,7 +309,7 @@ def scaled_tanh(
 
     def scaled_tanh_dyn_range_fn(dyn_range):
         def scaled_tanh_fn(x):
-            return alpha * torch.nn.Tanh(beta * x)
+            return alpha * torch.nn.functional.tanh(beta * x)
 
         return scaled_tanh_fn(dyn_range[0]), scaled_tanh_fn(dyn_range[1])
 
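Footnote on the pattern: each *_dyn_range_fn above appears to map an input dynamic range (min, max) through its activation so TensorRT can derive the output tensor's dynamic range, typically used for INT8 quantization. A hedged usage sketch of the corrected selu helper; the range endpoints are illustrative, and the real call site is convert_activation in this file:

    import torch

    def selu_dyn_range_fn(dyn_range):
        # Apply the activation to both ends of the input range; the
        # functional form returns tensors, which the old module-
        # constructing calls did not.
        return (
            torch.nn.functional.selu(dyn_range[0]),
            torch.nn.functional.selu(dyn_range[1]),
        )

    # SELU is monotonic, so the input endpoints map to the output endpoints.
    in_min, in_max = torch.tensor(-6.0), torch.tensor(6.0)
    out_min, out_max = selu_dyn_range_fn((in_min, in_max))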