
Commit a335af8

apbosegs-olive authored and committed
correcting nn_ops for leaky_relu and correcting linting error
1 parent fab3d79 commit a335af8


5 files changed (+10, -21 lines)


py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 1 addition & 6 deletions
@@ -1025,12 +1025,7 @@ def acc_ops_leaky_relu(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:

     return activation.leaky_relu(
-        network,
-        target,
-        SourceIR.ACC,
-        name,
-        kwargs["input"],
-        kwargs["negative_slope"]
+        network, target, SourceIR.ACC, name, kwargs["input"], kwargs["negative_slope"]
     )

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 1 addition & 9 deletions
@@ -216,7 +216,6 @@ def aten_ops_fmod(
     return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name)


-
 @tensorrt_converter(torch.ops.aten.leaky_relu.default)
 def aten_ops_leaky_relu(
     network: TRTNetwork,
@@ -226,14 +225,7 @@ def aten_ops_leaky_relu(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:

-    return activation.leaky_relu(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1]
-    )
+    return activation.leaky_relu(network, target, SourceIR.ATEN, name, args[0], args[1])


 @tensorrt_converter(torch.ops.aten.linear)

py/torch_tensorrt/fx/converters/impl/activation.py

Lines changed: 4 additions & 2 deletions
@@ -98,12 +98,14 @@ def leaky_relu(
     source_ir: Optional[SourceIR],
     name: str,
     input_val: TRTTensor,
-    alpha: Optional[Any]
+    alpha: Optional[Any],
 ):
     operation_type = trt.ActivationType.LEAKY_RELU

     def leaky_relu_dyn_range_fn(dyn_range):
-        return (max(0, dyn_range[0]) + alpha * min(0, dyn_range[0])), (max(0, dyn_range[1]) + alpha * min(0, dyn_range[1]))
+        return (max(0, dyn_range[0]) + alpha * min(0, dyn_range[0])), (
+            max(0, dyn_range[1]) + alpha * min(0, dyn_range[1])
+        )

     return convert_activation(
         network,
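
For context: leaky_relu(x) equals x for x >= 0 and alpha * x for x < 0, which is exactly max(0, x) + alpha * min(0, x). Because the function is monotonic for alpha >= 0, the output dynamic range is obtained by applying it to both endpoints of the input range, which is what leaky_relu_dyn_range_fn does. A minimal standalone sketch of that math (plain Python, not part of the diff):

# The dynamic-range math from leaky_relu_dyn_range_fn, in isolation:
# leaky_relu(x) = max(0, x) + alpha * min(0, x); since the function is
# monotonic for alpha >= 0, applying it to each endpoint of the input
# dynamic range yields the output dynamic range.
def leaky_relu_scalar(x: float, alpha: float) -> float:
    return max(0, x) + alpha * min(0, x)

def leaky_relu_dyn_range(dyn_range: tuple, alpha: float) -> tuple:
    low, high = dyn_range
    return leaky_relu_scalar(low, alpha), leaky_relu_scalar(high, alpha)

# Negative endpoint is scaled by alpha; positive endpoint passes through.
assert leaky_relu_dyn_range((-2.0, 3.0), alpha=0.01) == (-0.02, 3.0)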

py/torch_tensorrt/fx/converters/nn_ops_converters.py

Lines changed: 3 additions & 3 deletions
@@ -30,11 +30,11 @@ def relu(network, submod, args, kwargs, layer_name):
     # args/kwargs should have already been normalized to kwargs
     assert len(args) == 0

-    return activation.relu(
+    return activation.leaky_relu(
         network=network,
         target="torch.nn.functional.leaky_relu",
         source_ir=SourceIR.NN,
         name=layer_name,
         input_val=kwargs["input"],
-        alpha=kwargs["negative_slope"]
-    )
+        alpha=kwargs["negative_slope"],
+    )
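
This hunk is the substantive fix named in the commit message: the nn_ops converter was dispatching to activation.relu even though it passed leaky_relu's target and an alpha keyword, so the negative slope would be lost (or, depending on activation.relu's signature, the stray alpha keyword would raise a TypeError at conversion time). The two activations disagree on negative inputs, as this plain-PyTorch illustration (not part of the diff) shows:

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 0.0, 3.0])
print(F.relu(x))                            # tensor([0., 0., 3.])
print(F.leaky_relu(x, negative_slope=0.1))  # tensor([-0.2000, 0.0000, 3.0000])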

py/torch_tensorrt/fx/test/converters/aten_op/test_leaky_relu_aten.py

Lines changed: 1 addition & 1 deletion
@@ -50,4 +50,4 @@ def forward(self, x):


 if __name__ == "__main__":
-    run_tests()
+    run_tests()

(The removed and added lines render identically; this is presumably a whitespace-only fix, such as adding the newline at end of file flagged by the linter.)
