Commit e642f86 (parent 6848571)

feat: support aten.scalar_tensor dynamo converter (#2595)
File tree: 3 files changed, +127 −2 lines

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 13 additions & 0 deletions
@@ -2670,3 +2670,16 @@ def aten_ops_flip(
         args[0],
         args[1],
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.scalar_tensor.default)
+def aten_ops_scalar_tensor(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.scalar_tensor(
+        ctx, target, SourceIR.ATEN, name, args[0], dtype=kwargs.get("dtype")
+    )
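
For context, a quick eager-mode sketch of the op this converter maps: aten.scalar_tensor materializes a Python scalar as a 0-d tensor, and the dtype keyword overrides PyTorch's default type inference, mirroring the dtype=kwargs.get("dtype") plumbing above. This is runnable with stock PyTorch and is not part of the commit:

import torch

# aten.scalar_tensor turns a Python scalar into a 0-d tensor.
t = torch.ops.aten.scalar_tensor.default(2.99)
print(t.shape, t.dtype)  # torch.Size([]) torch.float32

# An explicit dtype overrides the default inference.
i = torch.ops.aten.scalar_tensor.default(1.0, dtype=torch.int32)
print(i, i.dtype)  # tensor(1, dtype=torch.int32) torch.int32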

py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py

Lines changed: 19 additions & 2 deletions
@@ -1,6 +1,8 @@
-from typing import Optional
+from typing import Optional, Union
 
+import numpy as np
 import tensorrt as trt
+import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
@@ -10,7 +12,8 @@
     get_trt_tensor,
 )
 from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary
-from torch_tensorrt.fx.types import TRTTensor
+from torch_tensorrt.fx.converters.converter_utils import set_layer_name
+from torch_tensorrt.fx.types import TRTDataType, TRTTensor
 
 
 def exp(
@@ -459,3 +462,17 @@ def trunc(
     return impl.elementwise.trunc_div(
         ctx, target, source_ir, f"{name}_trunc", input_val, dividend
     )
+
+
+def scalar_tensor(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    scalar: Union[int, float, bool],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
+) -> TRTTensor:
+    tensor = get_trt_tensor(ctx, scalar, f"{name}_scalar_tensor", dtype)
+    identity_layer = ctx.net.add_identity(tensor)
+    set_layer_name(identity_layer, target, name, source_ir)
+    return identity_layer.get_output(0)
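
The implementation registers the scalar as a TensorRT constant via get_trt_tensor and routes it through an identity layer, mainly so there is a concrete layer for set_layer_name to label. A minimal sketch of the same constant-plus-identity pattern against the raw TensorRT Python API; the builder setup, shape, and layer name here are illustrative assumptions, not taken from this commit:

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# The scalar enters the network as a constant layer...
weights = trt.Weights(np.array([2.99], dtype=np.float32))
constant = network.add_constant(shape=(1,), weights=weights)

# ...and the identity layer provides a nameable node, analogous to
# set_layer_name(identity_layer, ...) in scalar_tensor above.
identity = network.add_identity(constant.get_output(0))
identity.name = "scalar_tensor_identity"
output = identity.get_output(0)
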
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestScalarTensorConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (-2.00001,),
+            (-1.3,),
+            (-0.0,),
+            (1.0,),
+            (2.99,),
+        ]
+    )
+    def test_scalar_tensor_float(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (-9999,),
+            (-1,),
+            (0,),
+            (2,),
+            (99999,),
+        ]
+    )
+    def test_scalar_tensor_int(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    def test_scalar_tensor_bool(self, scalar):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+        )
+
+    @parameterized.expand(
+        [
+            (-9999, torch.int),
+            (-2.00001, torch.float),
+            (-1, torch.float),
+            (0, torch.int),
+            (-0.0, torch.float),
+            (1.0, torch.int),
+            (2.99, torch.float),
+            (9999999, None),
+            (9999999.99999, None),
+            (True, torch.bool),
+        ]
+    )
+    def test_scalar_tensor_dtype(self, scalar, dtype):
+        class ScalarTensor(nn.Module):
+            def forward(self):
+                return torch.ops.aten.scalar_tensor.default(scalar, dtype=dtype)
+
+        inputs = []
+        self.run_test(
+            ScalarTensor(),
+            inputs,
+            output_dtypes=None if dtype is None else [dtype],
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
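
Beyond the unit tests, a hedged end-to-end sketch of how the new converter would be exercised through the dynamo frontend. The module, input shapes, and scalar value are illustrative, not part of the commit, and lowering passes may constant-fold the scalar before conversion:

import torch
import torch_tensorrt

class AddScalar(torch.nn.Module):
    def forward(self, x):
        # Traced by dynamo as aten.scalar_tensor, handled by the
        # converter registered in this commit.
        return x + torch.ops.aten.scalar_tensor.default(2.99, dtype=torch.float)

model = AddScalar().eval().cuda()
inputs = [torch.randn(2, 3, device="cuda")]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_model(*inputs))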
