Commit b97a266

chohk88 authored and laikhtewari committed

feat: support aten.as_strided converter (#2735)

1 parent 4ae4516

3 files changed: +182 -0 lines changed
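For context, aten.as_strided reinterprets an input as a tensor of the requested size, where the element at index (i0, i1, ...) reads the flattened input at storage_offset + sum(i_k * stride[k]). A quick eager-mode sketch of those semantics (plain PyTorch, illustrative values, not part of the commit):

import torch

x = torch.arange(25.0).reshape(5, 5)
# out[i][j] = x.flatten()[storage_offset + i * stride[0] + j * stride[1]]
out = torch.as_strided(x, size=(2, 3), stride=(1, 2), storage_offset=0)
print(out)
# tensor([[0., 2., 4.],
#         [1., 3., 5.]])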

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 34 additions & 0 deletions
@@ -800,6 +800,40 @@ def aten_ops_tile(


def zero_output_validator(node: Node) -> bool:
    if 0 in node.args[1]:
        _LOGGER.debug(
            f"We do not support output tensors with zero-sized dimensions (got {node.args[1]}) for this operation."
        )
        return False
    else:
        return True


@dynamo_tensorrt_converter(
    torch.ops.aten.as_strided.default,
    capability_validator=zero_output_validator,
)
def aten_ops_as_strided(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    return impl.slice.as_strided(
        ctx,
        target,
        source_ir=SourceIR.ATEN,
        name=name,
        input=args[0],
        size=args[1],
        stride=args[2],
        storage_offset=args_bounds_check(args, 3, None),
    )


@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
@enforce_tensor_types(
    {
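For context, a quick eager-mode look at the case zero_output_validator filters out (plain PyTorch, illustrative values, not part of the commit): an output size containing 0 yields an empty tensor, which the gather-based lowering below has no indices to produce, so such nodes fall back to PyTorch instead of being converted.

import torch

x = torch.arange(25.0).reshape(5, 5)

# A zero-sized output dimension produces an empty tensor in eager mode,
# which is why zero_output_validator returns False for such nodes
empty = torch.as_strided(x, (0, 3), (1, 2))
print(empty.shape)  # torch.Size([0, 3])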

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 59 additions & 0 deletions
@@ -8,6 +8,7 @@

from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    flatten_dims,
    get_positive_dim,
    get_trt_tensor,
)

@@ -259,3 +260,61 @@ def flip(

    )
    set_layer_name(layer, target, name, source_ir)
    return layer.get_output(0)


def as_strided(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    size: Sequence[int],
    stride: Sequence[int],
    storage_offset: Optional[int],
) -> TRTTensor:
    # Default the storage offset to 0 before passing it to nested()
    if storage_offset is None:
        storage_offset = 0

    # Flatten the input to 1-D so strided elements can be gathered directly
    flatten_shape = flatten_dims(input, 0, -1)
    flatten_output = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_reshape_flatten_output", input, flatten_shape
    )

    indices = []

    # Recursively enumerate the flat indices selected by
    # (size, stride, storage_offset)
    def nested(
        rank: int, size: Sequence[int], stride: Sequence[int], current: int, dim: int
    ) -> None:
        if dim == rank:
            # All dimensions consumed: record the computed flat index
            indices.append(current)
            return
        for i in range(size[dim]):
            # Step by stride[dim] along this dimension, then recurse into the next
            nested(rank, size, stride, current + stride[dim] * i, dim + 1)

    nested(len(size), size, stride, storage_offset, 0)

    indices = np.array(indices, dtype=np.int32)
    indices_tensor = get_trt_tensor(ctx, indices, f"{name}_indices")

    # Use gather to pick the selected elements from the flattened input
    gather_layer = ctx.net.add_gather(flatten_output, indices_tensor, axis=0)
    gather_output = gather_layer.get_output(0)

    # Reshape the gathered 1-D tensor to the requested output size
    reshape_output = impl.shuffle.reshape(
        ctx,
        target,
        source_ir,
        f"{name}_reshape_gather_output",
        gather_output,
        tuple(size),
    )

    return reshape_output
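As a sanity check of the index enumeration, here is a minimal standalone sketch (hypothetical helper as_strided_ref, plain PyTorch only, not part of the commit) that mirrors nested() and can be compared against eager torch.as_strided:

import torch


def as_strided_ref(x, size, stride, storage_offset=0):
    # Enumerate flat indices exactly as nested() does in the converter
    indices = []

    def nested(current, dim):
        if dim == len(size):
            indices.append(current)
            return
        for i in range(size[dim]):
            nested(current + stride[dim] * i, dim + 1)

    nested(storage_offset, 0)
    flat = x.reshape(-1)
    return flat[torch.tensor(indices, dtype=torch.long)].reshape(size)


x = torch.arange(25.0).reshape(5, 5)
assert torch.equal(
    as_strided_ref(x, (2, 3), (1, 2)),
    torch.as_strided(x, (2, 3), (1, 2)),
)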
tests/py/dynamo/conversion/test_as_strided_aten.py

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
import torch
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase


class TestAsStridedConverter(DispatchTestCase):
    @parameterized.expand(
        [
            # (input_shape, output_size, stride, storage_offset)
            ((5, 5), (2, 3), (1, 2), 0),
            ((5, 5), (2, 3), (2, 2), 1),
            ((20, 20), (2, 3, 2), (2, 2, 2), 0),
            ((8, 8, 8), (2, 2, 3), (1, 2, 2), 1),
            ((200, 200, 200), (9, 9, 3, 2), (2, 2, 2, 3), 1),
            ((10, 25, 12), (3, 7, 3), (2, 1, 3), 1),
            ((10, 25, 12), (3, 7, 3), (2, 0, 3), 1),
            ((10, 25, 12, 100), (6, 5, 7, 10), (0, 0, 0, 0), 0),
            ((10, 25, 12, 100), (6, 5, 7, 10), (0, 0, 0, 0), 1),
        ]
    )
    def test_as_strided(
        self,
        input_shape,
        output_size,
        stride,
        storage_offset=0,
    ):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.as_strided.default(
                    x, output_size, stride, storage_offset
                )

        inputs = [torch.randn(input_shape)]
        self.run_test(
            TestModule(),
            inputs,
        )


if __name__ == "__main__":
    run_tests()
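For intuition about what these cases exercise, the second parameterized case evaluated eagerly (a quick illustration, not part of the test file) shows that as_strided outputs may overlap, reading the same input element more than once:

import torch

x = torch.arange(25.0).reshape(5, 5)
# size=(2, 3), stride=(2, 2), storage_offset=1: flat index 1 + 2*i + 2*j
print(torch.as_strided(x, (2, 3), (2, 2), 1))
# tensor([[1., 3., 5.],
#         [3., 5., 7.]])   <- 3. and 5. are read twice (overlapping view)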
