Skip to content

Commit a7170e2

Browse files
chohk88 authored and laikhtewari committed
feat: support aten.log1p converter (#2823)
1 parent 247d632 commit a7170e2

File tree

3 files changed

+141
-34
lines changed

3 files changed

+141
-34
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 51 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,57 @@ def aten_ops_log(
11651165
)
11661166

11671167

1168+
@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
def aten_ops_log2(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.log2`` to its TensorRT unary implementation."""
    # args[0] is the sole tensor operand; this op takes no keyword arguments.
    input_val = args[0]
    return impl.unary.log2(ctx, target, SourceIR.ATEN, name, input_val)
1183+
1184+
1185+
@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
def aten_ops_log10(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.log10`` to its TensorRT unary implementation."""
    # args[0] is the sole tensor operand; this op takes no keyword arguments.
    input_val = args[0]
    return impl.unary.log10(ctx, target, SourceIR.ATEN, name, input_val)
1200+
1201+
1202+
@dynamo_tensorrt_converter(torch.ops.aten.log1p.default)
def aten_ops_log1p(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.log1p`` to its TensorRT unary implementation."""
    # args[0] is the sole tensor operand; this op takes no keyword arguments.
    input_val = args[0]
    return impl.unary.log1p(ctx, target, SourceIR.ATEN, name, input_val)
1217+
1218+
11681219
@dynamo_tensorrt_converter(torch.ops.aten.sqrt.default)
11691220
def aten_ops_sqrt(
11701221
ctx: ConversionContext,
@@ -2849,23 +2900,6 @@ def aten_ops_flip(
28492900
)
28502901

28512902

2852-
@dynamo_tensorrt_converter(torch.ops.aten.log2.default)
2853-
def log2(
2854-
ctx: ConversionContext,
2855-
target: Target,
2856-
args: Tuple[Argument, ...],
2857-
kwargs: Dict[str, Argument],
2858-
name: str,
2859-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2860-
return impl.unary.log2(
2861-
ctx,
2862-
target,
2863-
SourceIR.ATEN,
2864-
name,
2865-
args[0],
2866-
)
2867-
2868-
28692903
@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
28702904
def aten_ops_scalar_tensor(
28712905
ctx: ConversionContext,
@@ -2879,23 +2913,6 @@ def aten_ops_scalar_tensor(
28792913
)
28802914

28812915

2882-
@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
2883-
def log10(
2884-
ctx: ConversionContext,
2885-
target: Target,
2886-
args: Tuple[Argument, ...],
2887-
kwargs: Dict[str, Argument],
2888-
name: str,
2889-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2890-
return impl.unary.log10(
2891-
ctx,
2892-
target,
2893-
SourceIR.ATEN,
2894-
name,
2895-
args[0],
2896-
)
2897-
2898-
28992916
@dynamo_tensorrt_converter(torch.ops.aten.roll.default)
29002917
@enforce_tensor_types(
29012918
{

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,23 @@ def log2(
119119
)
120120

121121

122+
def log1p(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    """Compute ``log(1 + x)`` for each element of ``input_val``.

    TensorRT exposes no dedicated log1p layer, so this lowers the op to an
    elementwise add of the scalar 1 followed by a natural-log unary layer.
    The two sub-layers get distinct ``_add`` / ``_log`` name suffixes so the
    resulting network layers stay uniquely named.
    """
    shifted = impl.elementwise.add(
        ctx, target, source_ir, f"{name}_add", input_val, 1
    )
    return log(ctx, target, source_ir, f"{name}_log", shifted)
137+
138+
122139
def sqrt(
123140
ctx: ConversionContext,
124141
target: Target,
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
from torch_tensorrt import Input
6+
7+
from .harness import DispatchTestCase
8+
9+
10+
class TestLog1pConverter(DispatchTestCase):
    """Dispatch tests for the ``aten.log1p`` TensorRT converter."""

    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_log1p_float(self, input_shape, dtype):
        class Log1p(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log1p.default(input)

        # abs() plus a small epsilon keeps 1 + x strictly positive so the
        # reference log1p output is finite and comparable.
        inputs = [torch.randn(input_shape, dtype=dtype).abs() + 0.001]
        self.run_test(
            Log1p(),
            inputs,
        )

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int, 0, 10),
            ((2, 3, 4), torch.int, 0, 5),
            ((2, 3, 4, 5), torch.int, 0, 5),
        ]
    )
    def test_log1p_int(self, input_shape, dtype, low, high):
        class Log1p(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log1p.default(input)

        # Fix: the original did `randint(...).abs() + 0.001`, but adding a
        # float scalar to an int tensor type-promotes it to float, so the
        # "int" test never actually exercised integer inputs. `low >= 0`
        # already guarantees 1 + x > 0 (log1p(0) == 0 is finite), so the
        # tensor is fed through unchanged, preserving its integer dtype.
        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(
            Log1p(),
            inputs,
        )

    @parameterized.expand(
        [
            (torch.full((1, 20), 2, dtype=torch.float),),
            (torch.full((2, 3, 4), 3, dtype=torch.float),),
            (torch.full((2, 3, 4, 5), 4, dtype=torch.float),),
        ]
    )
    def test_log1p_const_float(self, data):
        class Log1p(nn.Module):
            def forward(self, input):
                return torch.ops.aten.log1p.default(input)

        # Constant positive tensors: 1 + x > 0 by construction.
        inputs = [data]
        self.run_test(
            Log1p(),
            inputs,
        )
70+
71+
72+
# Allow running this test module directly via `python <file>`.
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)