Skip to content

Commit a546a9f

Browse files
chohk88laikhtewari
authored andcommitted
feat: support aten.expm1 converter (#2714)
1 parent 1e56b61 commit a546a9f

File tree

3 files changed

+112
-0
lines changed

3 files changed

+112
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,23 @@ def aten_ops_exp(
11361136
)
11371137

11381138

1139+
@dynamo_tensorrt_converter(torch.ops.aten.expm1.default)
1140+
def aten_ops_expm1(
1141+
ctx: ConversionContext,
1142+
target: Target,
1143+
args: Tuple[Argument, ...],
1144+
kwargs: Dict[str, Argument],
1145+
name: str,
1146+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
1147+
return impl.unary.expm1(
1148+
ctx,
1149+
target,
1150+
SourceIR.ATEN,
1151+
name,
1152+
args[0],
1153+
)
1154+
1155+
11391156
@dynamo_tensorrt_converter(torch.ops.aten.log.default)
11401157
def aten_ops_log(
11411158
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,32 @@ def exp(
4444
)
4545

4646

47+
def expm1(
48+
ctx: ConversionContext,
49+
target: Target,
50+
source_ir: Optional[SourceIR],
51+
name: str,
52+
input_val: TRTTensor,
53+
) -> TRTTensor:
54+
"""
55+
Computes e^x - 1 for each element of the input tensor.
56+
57+
Args:
58+
ctx (ConversionContext): TensorRT ConversionContext object.
59+
target (Target): fx node target.
60+
source_ir (SourceIR): Source IR calling the function
61+
name (str): Name of the fx node with optional suffix.
62+
input_val (TRTTensor): The input tensor.
63+
64+
Returns:
65+
TRTTensor: A TensorRT tensor represent the result of expm1 operator.
66+
"""
67+
# Compute e^x for each element of the input tensor
68+
exp_result = exp(ctx, target, source_ir, f"{name}_exp", input_val)
69+
70+
return impl.elementwise.sub(ctx, target, source_ir, f"{name}_sub", exp_result, 1)
71+
72+
4773
def log(
4874
ctx: ConversionContext,
4975
target: Target,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from math import exp
2+
3+
import torch
4+
import torch.nn as nn
5+
from parameterized import parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
8+
from .harness import DispatchTestCase
9+
10+
11+
class TestExpConverter(DispatchTestCase):
12+
@parameterized.expand(
13+
[
14+
((10,), torch.float),
15+
((1, 20), torch.float),
16+
((2, 3, 4), torch.float),
17+
((2, 3, 4, 5), torch.float),
18+
]
19+
)
20+
def test_expm1_float(self, input_shape, dtype):
21+
class expm1(nn.Module):
22+
def forward(self, input):
23+
return torch.ops.aten.expm1.default(input)
24+
25+
inputs = [torch.randn(input_shape, dtype=dtype)]
26+
self.run_test(
27+
expm1(),
28+
inputs,
29+
)
30+
31+
@parameterized.expand(
32+
[
33+
(torch.full((1, 20), exp(1), dtype=torch.float),),
34+
(torch.full((2, 3, 4), exp(2), dtype=torch.float),),
35+
(torch.full((2, 3, 4, 5), exp(3), dtype=torch.float),),
36+
]
37+
)
38+
def test_expm1_exp_const_float(self, data):
39+
class expm1(nn.Module):
40+
def forward(self, input):
41+
return torch.ops.aten.expm1.default(input)
42+
43+
inputs = [data]
44+
self.run_test(
45+
expm1(),
46+
inputs,
47+
)
48+
49+
@parameterized.expand(
50+
[
51+
((10,), torch.int, 0, 5),
52+
((1, 20), torch.int32, -10, 10),
53+
((2, 3, 4), torch.int, -5, 5),
54+
]
55+
)
56+
def test_exp_int(self, input_shape, dtype, low, high):
57+
class expm1(nn.Module):
58+
def forward(self, input):
59+
return torch.ops.aten.expm1.default(input)
60+
61+
inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
62+
self.run_test(
63+
expm1(),
64+
inputs,
65+
)
66+
67+
68+
if __name__ == "__main__":
69+
run_tests()

0 commit comments

Comments
 (0)