Commit 15430ca

HolyWu authored and cehongwang committed
Add support for prelu dynamo converter (#2972)
1 parent 46495d8 commit 15430ca

File tree

5 files changed: +113 -1 lines changed
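For context, a minimal end-to-end sketch of what this commit enables (the model, shapes, and compile call below are illustrative, not taken from the commit): with the converter registered, nn.PReLU lowers through the dynamo frontend to TensorRT's parametric ReLU layer instead of being decomposed into primitive ops.

import torch
import torch_tensorrt

# Illustrative only: nn.PReLU lowers to aten._prelu_kernel, which the
# new converter maps directly to a TensorRT parametric ReLU layer.
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.PReLU()).cuda().eval()
inputs = [torch.randn(1, 10).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs).shape)  # torch.Size([1, 10])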

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 26 additions & 0 deletions

@@ -3395,3 +3395,29 @@ def aten_ops_native_dropout(
         args[1],
         args_bounds_check(args, 2, None),
     )
+
+
+@dynamo_tensorrt_converter(
+    torch.ops.aten._prelu_kernel.default, supports_dynamic_shapes=True
+)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        1: (TRTTensor,),
+    }
+)
+def aten_ops_prelu(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.prelu.prelu(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+    )
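For reference, aten._prelu_kernel applies an elementwise parametric ReLU: positive inputs pass through unchanged, negative inputs are scaled by the (broadcast) weight. A quick eager-mode check of those semantics, with shapes borrowed from the new tests below (the comparison itself is not part of the commit):

import torch

x = torch.randn(1, 10)
w = torch.randn(1, 1)
ref = torch.ops.aten._prelu_kernel.default(x, w)
# Positive elements unchanged; negative elements scaled by the weight.
assert torch.allclose(ref, torch.where(x >= 0, x, w * x))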

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -18,6 +18,7 @@
     pad,
     permutation,
     pool,
+    prelu,
     quantize,
     reduce,
     select,
py/torch_tensorrt/dynamo/conversion/impl/prelu.py

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+from typing import Optional
+
+from torch.fx.node import Target
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import set_layer_name
+from torch_tensorrt.dynamo.types import TRTTensor
+
+
+def prelu(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    weight: TRTTensor,
+) -> TRTTensor:
+    layer = ctx.net.add_parametric_relu(input, weight)
+    set_layer_name(layer, target, name, source_ir)
+    return layer.get_output(0)
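The impl is a thin wrapper around TensorRT's INetworkDefinition.add_parametric_relu, which computes y = x where x >= 0 and y = slopes * x elsewhere, with slopes broadcast against the input. A standalone sketch of the same layer using the TensorRT Python API directly (illustrative, not from the commit; builder config and engine build omitted):

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# The slopes tensor plays the role of the PReLU weight and must be
# broadcastable to the input's shape.
x = network.add_input("x", trt.float32, (1, 10))
slopes = network.add_input("weight", trt.float32, (1, 1))
layer = network.add_parametric_relu(x, slopes)
network.mark_output(layer.get_output(0))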

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion

@@ -108,7 +108,6 @@
     aten.norm,
     aten.ones,
     aten.ones_like,
-    aten._prelu_kernel,
     aten._prelu_kernel_backward,
     aten._reshape_alias,
     aten.rad2deg,
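With a dedicated converter in place, aten._prelu_kernel is removed from the decomposition set so the op reaches the converter intact rather than being rewritten into primitive ops; aten._prelu_kernel_backward stays listed, since only the forward kernel gains a converter in this change.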
tests/py/dynamo/conversion/test_prelu_aten.py

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestPReLUConverter(DispatchTestCase):
+    def test_prelu(self):
+        class TestModule(nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._prelu_kernel.default(x, weight)
+
+        inputs = [torch.randn(1, 10), torch.randn(1, 1)]
+        self.run_test(TestModule(), inputs)
+
+    def test_prelu_with_dynamic_shape(self):
+        class TestModule(nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._prelu_kernel.default(x, weight)
+
+        input_specs = [
+            Input(
+                min_shape=(1, 1, 1),
+                opt_shape=(1, 2, 3),
+                max_shape=(3, 3, 3),
+                dtype=torch.float32,
+                name="x",
+            ),
+            Input(
+                min_shape=(1, 1, 1),
+                opt_shape=(1, 1, 1),
+                max_shape=(1, 1, 1),
+                dtype=torch.float32,
+                name="weight",
+            ),
+        ]
+        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+
+    def test_prelu_with_dynamic_shape_four_dimensions(self):
+        class TestModule(nn.Module):
+            def forward(self, x, weight):
+                return torch.ops.aten._prelu_kernel.default(x, weight)
+
+        input_specs = [
+            Input(
+                min_shape=(1, 1, 1, 5),
+                opt_shape=(1, 2, 3, 5),
+                max_shape=(3, 3, 3, 5),
+                dtype=torch.float32,
+                name="x",
+            ),
+            Input(
+                min_shape=(1, 1, 1, 1),
+                opt_shape=(1, 2, 1, 1),
+                max_shape=(1, 3, 1, 1),
+                dtype=torch.float32,
+                name="weight",
+            ),
+        ]
+        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+
+
+if __name__ == "__main__":
+    run_tests()
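To run just the new tests (test path assumed from the repo's conventional layout):

python -m pytest tests/py/dynamo/conversion/test_prelu_aten.py -v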
