Skip to content

Commit adb8c1b

Browse files
authored
feat: support aten.arange.start_step dynamo converter (#2505)
1 parent a24d111 commit adb8c1b

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import operator
33
from typing import Dict, Sequence, Tuple, Union
44

5+
import numpy as np
56
import torch
67
from torch.fx.node import Argument, Node, Target
78
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -35,3 +36,14 @@ def generic_evaluator(
3536
f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
3637
)
3738
return target(*args)
39+
40+
41+
@dynamo_tensorrt_converter(torch.ops.aten.arange.start_step)
42+
def aten_ops_arange_start_step(
43+
ctx: ConversionContext,
44+
target: Target,
45+
args: Tuple[Argument, ...],
46+
kwargs: Dict[str, Argument],
47+
name: str,
48+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
49+
return np.arange(*args)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestArangeConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
(0, 5, 1),
13+
(1, 5, 2),
14+
(3, 5, 3),
15+
(5, 0, -1),
16+
(5, 1, -2),
17+
(5, 3, -3),
18+
]
19+
)
20+
def test_arange(self, start, end, step):
21+
class Arange(nn.Module):
22+
def forward(self, x):
23+
return torch.ops.aten.arange.start_step(start, x.shape[0], step)
24+
25+
inputs = [torch.randn(end, 1)]
26+
self.run_test(
27+
Arange(),
28+
inputs,
29+
use_dynamo_tracer=True,
30+
)
31+
32+
33+
if __name__ == "__main__":
34+
run_tests()

0 commit comments

Comments
 (0)