
Commit c7a26fa

fix: Switch all copies to force cast (#2563)
1 parent 03c92a5 commit c7a26fa

File tree

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
tests/py/dynamo/conversion/test_casts.py

2 files changed, +17 −2 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions

@@ -895,7 +895,7 @@ def aten_ops_clone_copy_dtype(
         name,
         args[0],
         kwargs.get("dtype", args[0].dtype),
-        force_layer=False,
+        force_layer=True,
     )


@@ -1027,7 +1027,7 @@ def aten_ops_sum(
             name,
             sum_,
             kwargs["output_dtype"],
-            force_layer=False,
+            force_layer=True,
         )
     else:
         return sum_
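For context: judging by the flag name and the commit title, force_layer determines whether impl.cast.to_copy may elide a cast whose target dtype already matches the input (returning the input tensor unchanged) or must always emit a cast layer so every copy produces its own output tensor. Below is a minimal, self-contained sketch of that decision; FakeTensor, FakeNetwork, and to_copy_sketch are hypothetical stand-ins for illustration, not the torch_tensorrt conversion API.

# Illustrative sketch only: these types and the helper are hypothetical,
# not the torch_tensorrt source.
from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeTensor:
    name: str
    dtype: str


@dataclass
class FakeNetwork:
    layers: List[str] = field(default_factory=list)

    def add_cast(self, tensor: FakeTensor, dtype: str, name: str) -> FakeTensor:
        # Every call creates a new layer and a new, distinct output tensor.
        self.layers.append(name)
        return FakeTensor(name=f"{name}_output", dtype=dtype)


def to_copy_sketch(net: FakeNetwork, inp: FakeTensor, dtype: str, name: str,
                   force_layer: bool = False) -> FakeTensor:
    # With force_layer=False, a same-dtype copy is elided and the caller gets
    # the original tensor back; with force_layer=True, a layer is always added.
    if not force_layer and inp.dtype == dtype:
        return inp
    return net.add_cast(inp, dtype, name)


net = FakeNetwork()
x_1 = FakeTensor("x_1", "float32")

y = to_copy_sketch(net, x_1, "float32", "copy_y", force_layer=False)
z = to_copy_sketch(net, x_1, "float32", "copy_z", force_layer=False)
print(y is z)  # True: both copies alias the same tensor object

y = to_copy_sketch(net, x_1, "float32", "copy_y", force_layer=True)
z = to_copy_sketch(net, x_1, "float32", "copy_z", force_layer=True)
print(y is z)  # False: each copy gets its own cast layer and output

Two _to_copy nodes fed by the same intermediate hit exactly this aliasing case, which appears to be what the new test below targets.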

tests/py/dynamo/conversion/test_casts.py

Lines changed: 15 additions & 0 deletions

@@ -75,6 +75,21 @@ def forward(self, x):
             inputs,
         )

+    def test_to_copy_multiple_returns(self):
+        class ToCopyReturns(nn.Module):
+            def forward(self, x):
+                x_1 = x + 1
+                y = torch.ops.aten._to_copy.default(x_1, dtype=torch.float)
+                z = torch.ops.aten._to_copy.default(x_1, dtype=torch.float)
+                return y, z
+
+        inputs = [torch.rand((1, 3, 10))]
+        self.run_test(
+            ToCopyReturns(),
+            inputs,
+            precision=torch.float,
+        )
+

 if __name__ == "__main__":
     run_tests()
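Assuming a working torch_tensorrt test environment, the new case can be run on its own with pytest:

python -m pytest tests/py/dynamo/conversion/test_casts.py -k test_to_copy_multiple_returns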
