
Commit d75f588

feat: Support aten.gelu dynamo converter (#3134)
1 parent 0f8f23d commit d75f588

4 files changed: +75 -17 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 19 additions & 1 deletion
@@ -19,7 +19,7 @@
     get_positive_dim,
     is_only_operator_on_placeholder,
 )
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -548,6 +548,24 @@ def aten_ops_hard_sigmoid(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.gelu.default, supports_dynamic_shapes=True)
+def aten_ops_gelu(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.gelu(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        kwargs.get("approximate", "none"),
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
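
With this converter registered (and the matching decomposition removed, see _decomposition_groups.py below), aten.gelu.default nodes are handed to TensorRT as a single activation layer instead of being expanded into primitive ops. A minimal end-to-end sketch, assuming a CUDA-capable setup with a torch_tensorrt build that includes this commit (module and variable names are illustrative):

import torch
import torch_tensorrt

class GELUModule(torch.nn.Module):
    def forward(self, x):
        # Traces to torch.ops.aten.gelu.default, which now hits aten_ops_gelu
        return torch.nn.functional.gelu(x, approximate="tanh")

model = GELUModule().eval().cuda()
inputs = [torch.randn(1, 10).cuda()]

# Compile through the dynamo frontend; supports_dynamic_shapes=True means
# the converter is also eligible when inputs have dynamic dimensions.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)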

py/torch_tensorrt/dynamo/conversion/impl/activation/ops.py

Lines changed: 24 additions & 1 deletion
@@ -7,7 +7,7 @@
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.impl.activation.base import convert_activation
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.dynamo.types import TRTTensor
 
 
 def relu(
@@ -327,3 +327,26 @@ def thresholded_relu_fn(x: float) -> float:
         alpha=alpha,
         dyn_range_fn=thresholded_relu_dyn_range_fn,
     )
+
+
+def gelu(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+    approximate: str,
+) -> TRTTensor:
+    if approximate == "none":
+        operation_type = trt.ActivationType.GELU_ERF
+    elif approximate == "tanh":
+        operation_type = trt.ActivationType.GELU_TANH
+
+    return convert_activation(
+        ctx,
+        target,
+        source_ir,
+        name,
+        operation_type,
+        input_val,
+    )
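
The approximate argument selects between TensorRT's two GELU activation types: GELU_ERF for the exact definition 0.5 * x * (1 + erf(x / sqrt(2))) and GELU_TANH for the tanh approximation. Since aten.gelu.default only accepts "none" or "tanh", the if/elif above cannot fall through in practice, though any other value would leave operation_type unbound. A quick numeric sketch of the two formulas against PyTorch's own reference (illustrative, runs on CPU):

import math
import torch
import torch.nn.functional as F

x = torch.randn(1000, dtype=torch.float64)

# Exact (erf-based) GELU, the formula behind GELU_ERF.
gelu_erf = 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))

# Tanh approximation, the formula behind GELU_TANH.
gelu_tanh = 0.5 * x * (
    1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3))
)

assert torch.allclose(gelu_erf, F.gelu(x))
assert torch.allclose(gelu_tanh, F.gelu(x, approximate="tanh"))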

py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@
     aten.fill,
     aten.frac,
     aten._fused_moving_avg_obs_fq_helper,
-    aten.gelu,
     aten.gelu_backward,
     aten.glu_backward,
     aten.hardshrink,
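
Removing aten.gelu from this decomposition group is the other half of the change: with a native converter now registered, lowering should leave gelu as a single graph node rather than decomposing it into erf/tanh primitives (aten.gelu_backward stays in the group, since only the forward op gained a converter). A small sketch of the idea with plain torch.export, illustrative only, since Torch-TensorRT applies its decomposition table during its own lowering pass:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.gelu(x)

ep = torch.export.export(M(), (torch.randn(1, 10),))

# With no gelu decomposition applied, the graph keeps a single
# aten.gelu node for the converter to pick up.
print([n.target for n in ep.graph.nodes if n.op == "call_function"])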

tests/py/dynamo/conversion/test_gelu_aten.py

Lines changed: 32 additions & 14 deletions
@@ -1,49 +1,67 @@
-import pytest
 import torch
 import torch.nn as nn
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
 
-@pytest.mark.skip(reason="This test will be skipped.")
-class TestGeLUConverter(DispatchTestCase):
-    def test_gelu(self):
+class TestGELUConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu(self, approximate):
         class TestModule(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)
 
         inputs = [torch.randn(1, 10)]
         self.run_test(TestModule(), inputs)
 
-    def test_gelu_with_dynamic_shape(self):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu_with_dynamic_shape(self, approximate):
         class TestModule(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)
 
         input_specs = [
             Input(
-                shape=(-1, -1, -1),
+                min_shape=(1, 1, 1),
+                opt_shape=(1, 2, 3),
+                max_shape=(3, 3, 3),
                 dtype=torch.float32,
-                shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))],
             ),
         ]
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
 
-    def test_gelu_with_dynamic_shape_four_dimensions(self):
+    @parameterized.expand(
+        [
+            ("none",),
+            ("tanh",),
+        ]
+    )
+    def test_gelu_with_dynamic_shape_four_dimensions(self, approximate):
         class TestModule(nn.Module):
             def forward(self, x):
-                return torch.ops.aten.gelu.default(x)
+                return torch.ops.aten.gelu.default(x, approximate=approximate)
 
         input_specs = [
             Input(
-                shape=(-1, -1, -1, -1),
+                min_shape=(1, 1, 1, 5),
+                opt_shape=(1, 2, 3, 5),
+                max_shape=(3, 3, 3, 5),
                 dtype=torch.float32,
-                shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))],
             ),
         ]
-
         self.run_test_with_dynamic_shape(TestModule(), input_specs)
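
The previously skipped tests are re-enabled and parameterized over both approximate modes; under parameterized's default naming, each tuple yields its own test case (e.g. test_gelu_0_none and test_gelu_1_tanh). To run the file locally, assuming a development install of Torch-TensorRT on a TensorRT-capable GPU:

python -m pytest tests/py/dynamo/conversion/test_gelu_aten.py -v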
