Commit 555ed55

feat: support converter for torch.log10
1 parent 9a100b6

File tree

3 files changed: +82 -0 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 17 additions & 0 deletions
@@ -2708,6 +2708,23 @@ def aten_ops_scalar_tensor(
     )


+@dynamo_tensorrt_converter(torch.ops.aten.log10.default)
+def log10(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.log10(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.roll.default)
 @enforce_tensor_types(
     {
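Once this converter is registered, the Dynamo frontend can map graphs containing aten.log10.default onto TensorRT. A minimal sketch of exercising it end to end (hypothetical module and shapes; assumes a CUDA device and a working torch_tensorrt/TensorRT install):

import torch
import torch_tensorrt

class Log10Module(torch.nn.Module):
    def forward(self, x):
        # torch.log10 lowers to torch.ops.aten.log10.default under Dynamo tracing
        return torch.log10(x)

model = Log10Module().eval().cuda()
inputs = [torch.rand(2, 3, 4).cuda()]

# Compile through the Dynamo path; the converter registered above
# handles the aten.log10.default node.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))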

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 16 additions & 0 deletions
@@ -61,6 +61,22 @@ def log(
     )


+def log10(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    log_layer_output = log(ctx, target, source_ir, f"{name}_log", input_val)
+
+    ln10 = 2.302585092994046
+
+    return impl.elementwise.div(
+        ctx, target, source_ir, f"{name}_div", log_layer_output, ln10
+    )
+
+
 def sqrt(
     ctx: ConversionContext,
     target: Target,
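The implementation leans on the change-of-base identity log10(x) = ln(x) / ln(10): it chains the existing natural-log converter with an elementwise division by the constant ln 10 ≈ 2.302585092994046. A quick eager-mode check of that identity (plain PyTorch, no TensorRT involved):

import torch

x = torch.rand(2, 3, 4) + 0.1  # strictly positive inputs
expected = torch.log10(x)
via_change_of_base = torch.log(x) / 2.302585092994046  # ln(10)

# Both paths agree to floating-point tolerance.
assert torch.allclose(expected, via_change_of_base, atol=1e-6)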
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestLogConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.float),
+            ((1, 20), torch.float),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_log10_float(self, input_shape, dtype):
+        class log10(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.log10.default(input)
+
+        inputs = [torch.randn(input_shape, dtype=dtype)]
+        self.run_test(
+            log10(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            ((10,), torch.int, 0, 5),
+            ((1, 20), torch.int32, -10, 10),
+            ((2, 3, 4), torch.int, -5, 5),
+        ]
+    )
+    def test_log10_int(self, input_shape, dtype, low, high):
+        class log10(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.log10.default(input)
+
+        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
+        self.run_test(
+            log10(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
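One detail worth noting about the integer tests: in eager PyTorch, log10 type-promotes integer inputs to a floating dtype, and the DispatchTestCase harness compares the TensorRT output against that eager reference. A standalone illustration of the promotion behavior (no TensorRT required):

import torch

x = torch.randint(-5, 5, (2, 3, 4), dtype=torch.int32)
y = torch.log10(x)

print(y.dtype)  # torch.float32: integer inputs are promoted to float
# Zero entries yield -inf and negative entries yield nan,
# matching the eager reference the harness compares against.
print(y)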
