Skip to content

Commit cd7808f

Browse files
committed
[feat] support converter for torch.log2
1 parent 6848571 commit cd7808f

File tree

3 files changed

+81
-0
lines changed

3 files changed

+81
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,3 +2670,20 @@ def aten_ops_flip(
26702670
args[0],
26712671
args[1],
26722672
)
2673+
2674+
2675+
@dynamo_tensorrt_converter(torch.log2)
def log2(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert a ``torch.log2`` call into TensorRT layers.

    Only the first positional argument (the input tensor) is consumed;
    the heavy lifting is delegated to the shared unary implementation.
    """
    # NOTE(review): this registers the python-level ``torch.log2`` rather
    # than an ATen overload (e.g. ``torch.ops.aten.log2.default``) —
    # confirm this is the intended registration target for dynamo tracing.
    input_tensor = args[0]
    return impl.unary.log2(ctx, target, SourceIR.ATEN, name, input_tensor)

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
cast_trt_tensor,
1010
get_trt_tensor,
1111
)
12+
from torch_tensorrt.dynamo.conversion.impl.elementwise import div
1213
from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
1314
from torch_tensorrt.fx.types import TRTTensor
1415

@@ -58,6 +59,20 @@ def log(
5859
)
5960

6061

62+
def log2(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input_val: TRTTensor,
) -> TRTTensor:
    """Compute the elementwise base-2 logarithm of ``input_val``.

    Implemented as ``ln(x) / ln(2)`` by chaining the existing natural-log
    unary helper with an elementwise division (presumably because the
    backend only exposes a natural-log unary op — confirm).
    """
    # ln(2) written out to full double precision; dividing ln(x) by it
    # rescales the result from base e to base 2.
    ln2 = 0.693147180559945309
    natural_log = log(ctx, target, source_ir, f"{name}_log", input_val)
    return div(ctx, target, source_ir, f"{name}_div", natural_log, ln2)
74+
75+
6176
def sqrt(
6277
ctx: ConversionContext,
6378
target: Target,
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestLogConverter(DispatchTestCase):
    """Dispatch tests verifying the ``torch.log2`` converter."""

    @parameterized.expand(
        [
            ((10,), torch.float),
            ((1, 20), torch.float),
            ((2, 3, 4), torch.float),
            ((2, 3, 4, 5), torch.float),
        ]
    )
    def test_log_float(self, input_shape, dtype):
        # Minimal module wrapping only the op under test.
        class Log2Module(nn.Module):
            def forward(self, input):
                return torch.log2(input)

        inputs = [torch.randn(input_shape, dtype=dtype)]
        self.run_test(Log2Module(), inputs)

    @parameterized.expand(
        [
            ((10,), torch.int, 0, 5),
            ((1, 20), torch.int32, -10, 10),
            ((2, 3, 4), torch.int, -5, 5),
        ]
    )
    def test_log_int(self, input_shape, dtype, low, high):
        # Integer inputs exercise the cast path. NOTE(review): the sampled
        # ranges include zero and negatives, so log2 can produce -inf/nan —
        # presumably the harness comparison tolerates this; confirm.
        class Log2Module(nn.Module):
            def forward(self, input):
                return torch.log2(input)

        inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
        self.run_test(Log2Module(), inputs)
46+
47+
# Allow running this test module directly (outside a pytest/CI harness).
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)