Skip to content

Commit a42dfb3

Browse files
authored
feat: support aten.copy dynamo converter (#2550)
1 parent 80db13c commit a42dfb3

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2547,3 +2547,28 @@ def aten_ops_trunc(
25472547
name,
25482548
args[0],
25492549
)
2550+
2551+
2552+
@dynamo_tensorrt_converter(torch.ops.aten.copy.default)
2553+
@enforce_tensor_types(
2554+
{
2555+
1: (TRTTensor,),
2556+
}
2557+
)
2558+
def aten_ops_copy(
2559+
ctx: ConversionContext,
2560+
target: Target,
2561+
args: Tuple[Argument, ...],
2562+
kwargs: Dict[str, Argument],
2563+
name: str,
2564+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2565+
src = args[1]
2566+
return impl.cast.to_copy(
2567+
ctx,
2568+
target,
2569+
SourceIR.ATEN,
2570+
name,
2571+
src,
2572+
src.dtype,
2573+
force_layer=True,
2574+
)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import torch
2+
import torch.nn as nn
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import run_tests
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestCopyConverter(DispatchTestCase):
10+
@parameterized.expand(
11+
[
12+
((3,), (3,), False),
13+
((1, 10), (1, 10), False),
14+
((2, 3, 4), (2, 3, 4), True),
15+
((2, 3, 4, 5), (2, 3, 4, 5), True),
16+
]
17+
)
18+
def test_copy_float(self, input_shape, src_shape, non_blocking):
19+
class Copy(nn.Module):
20+
def forward(self, input, src):
21+
return torch.ops.aten.copy.default(input, src, non_blocking)
22+
23+
inputs = [torch.randn(input_shape), torch.randn(src_shape)]
24+
self.run_test(
25+
Copy(),
26+
inputs,
27+
)
28+
29+
30+
if __name__ == "__main__":
31+
run_tests()

0 commit comments

Comments
 (0)