Skip to content

Commit 32a41ac

Browse files
committed
Add support for prelu dynamo converter
1 parent 058ee5c commit 32a41ac

File tree

5 files changed

+113
-1
lines changed

5 files changed

+113
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3370,3 +3370,29 @@ def aten_ops_native_dropout(
33703370
args[1],
33713371
args_bounds_check(args, 2, None),
33723372
)
3373+
3374+
3375+
@dynamo_tensorrt_converter(
3376+
torch.ops.aten._prelu_kernel.default, supports_dynamic_shapes=True
3377+
)
3378+
@enforce_tensor_types(
3379+
{
3380+
0: (TRTTensor,),
3381+
1: (TRTTensor,),
3382+
}
3383+
)
3384+
def aten_ops_prelu(
3385+
ctx: ConversionContext,
3386+
target: Target,
3387+
args: Tuple[Argument, ...],
3388+
kwargs: Dict[str, Argument],
3389+
name: str,
3390+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
3391+
return impl.prelu.prelu(
3392+
ctx,
3393+
target,
3394+
SourceIR.ATEN,
3395+
name,
3396+
args[0],
3397+
args[1],
3398+
)

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
pad,
1919
permutation,
2020
pool,
21+
prelu,
2122
quantize,
2223
reduce,
2324
select,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from typing import Optional
2+
3+
from torch.fx.node import Target
4+
from torch_tensorrt.dynamo._SourceIR import SourceIR
5+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
6+
from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name
7+
from torch_tensorrt.dynamo.types import TRTTensor
8+
9+
10+
def prelu(
11+
ctx: ConversionContext,
12+
target: Target,
13+
source_ir: Optional[SourceIR],
14+
name: str,
15+
input: TRTTensor,
16+
weight: TRTTensor,
17+
) -> TRTTensor:
18+
layer = ctx.net.add_parametric_relu(input, weight)
19+
set_layer_name(layer, target, name, source_ir)
20+
return layer.get_output(0)

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@
108108
aten.norm,
109109
aten.ones,
110110
aten.ones_like,
111-
aten._prelu_kernel,
112111
aten._prelu_kernel_backward,
113112
aten._reshape_alias,
114113
aten.rad2deg,
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import torch
2+
import torch.nn as nn
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestPReLUConverter(DispatchTestCase):
10+
def test_prelu(self):
11+
class TestModule(nn.Module):
12+
def forward(self, x, weight):
13+
return torch.ops.aten._prelu_kernel.default(x, weight)
14+
15+
inputs = [torch.randn(1, 10), torch.randn(1, 1)]
16+
self.run_test(TestModule(), inputs)
17+
18+
def test_prelu_with_dynamic_shape(self):
19+
class TestModule(nn.Module):
20+
def forward(self, x, weight):
21+
return torch.ops.aten._prelu_kernel.default(x, weight)
22+
23+
input_specs = [
24+
Input(
25+
min_shape=(1, 1, 1),
26+
opt_shape=(1, 2, 3),
27+
max_shape=(3, 3, 3),
28+
dtype=torch.float32,
29+
name="x",
30+
),
31+
Input(
32+
min_shape=(1, 1, 1),
33+
opt_shape=(1, 2, 1),
34+
max_shape=(1, 3, 1),
35+
dtype=torch.float32,
36+
name="weight",
37+
),
38+
]
39+
self.run_test_with_dynamic_shape(TestModule(), input_specs)
40+
41+
def test_prelu_with_dynamic_shape_four_dimensions(self):
42+
class TestModule(nn.Module):
43+
def forward(self, x, weight):
44+
return torch.ops.aten._prelu_kernel.default(x, weight)
45+
46+
input_specs = [
47+
Input(
48+
min_shape=(1, 1, 1, 5),
49+
opt_shape=(1, 2, 3, 5),
50+
max_shape=(3, 3, 3, 5),
51+
dtype=torch.float32,
52+
name="x",
53+
),
54+
Input(
55+
min_shape=(1, 1, 1, 5),
56+
opt_shape=(1, 2, 3, 5),
57+
max_shape=(3, 3, 3, 5),
58+
dtype=torch.float32,
59+
name="weight",
60+
),
61+
]
62+
self.run_test_with_dynamic_shape(TestModule(), input_specs)
63+
64+
65+
if __name__ == "__main__":
66+
run_tests()

0 commit comments

Comments
 (0)