Skip to content

Commit 60bfa04

Browse files
authored
fix: Error with aten.view across Tensor memory (#2464)
1 parent 867dc7b commit 60bfa04

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
1212
from .repair_input_as_output import repair_input_as_output
1313
from .replace_max_pool_with_indices import replace_max_pool_with_indices
14+
from .view_to_reshape import view_to_reshape
1415

1516
ATEN_LOWERING_PASSES = DynamoPassManager.build_from_passlist(
1617
[
@@ -21,6 +22,7 @@
2122
lower_linear,
2223
fuse_prims_broadcast,
2324
replace_max_pool_with_indices,
25+
view_to_reshape,
2426
]
2527
)
2628

py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import logging
2+
from typing import Callable, List, Sequence, Tuple
3+
4+
import torch
5+
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
6+
clean_up_graph_after_modifications,
7+
)
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
def view_to_reshape(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    """Rewrite every ``aten.view`` call in ``gm`` as an ``aten.reshape`` call.

    Args:
        gm: Graph module whose graph is searched for the view pattern.
        sample_inputs: Not used by this pass; accepted so the signature lines
            up with the other lowering passes.

    Returns:
        ``gm`` untouched when the rewriter reports no matches, otherwise the
        module returned by ``clean_up_graph_after_modifications``.
    """
    pattern, substitute = view_replacement()
    matches = torch.fx.subgraph_rewriter.replace_pattern(gm, pattern, substitute)

    # Nothing was rewritten, so there is nothing to clean up or log.
    if not matches:
        return gm

    gm = clean_up_graph_after_modifications(gm)
    logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")
    return gm
23+
24+
25+
def view_replacement() -> (
    Tuple[
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
    ]
):
    """Build the pattern/replacement function pair used to rewrite ``aten.view``.

    Both returned objects are plain Python functions taking ``(input, shape)``;
    neither is a ``torch.fx.GraphModule``, so both are annotated as callables.

    Returns:
        A ``(orig, replacement)`` tuple where ``orig`` calls
        ``torch.ops.aten.view.default`` and ``replacement`` calls
        ``torch.ops.aten.reshape.default`` with the same two arguments.
    """

    # Pattern to search for: a direct aten.view call
    def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.view.default(input, shape)

    # Substitute: the same call routed through aten.reshape
    def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
        return torch.ops.aten.reshape.default(input, shape)

    return orig, replacement

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
2-
import torch_tensorrt
32
from torch.testing._internal.common_utils import TestCase, run_tests
43

4+
import torch_tensorrt
5+
56
from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
67

78

@@ -375,5 +376,70 @@ def forward(self, input, weight, bias):
375376
torch._dynamo.reset()
376377

377378

379+
class TestLowerViewToReshape(TestCase):
    """Tests for the lowering pass that replaces ``aten.view`` with ``aten.reshape``."""

    def test_view_to_reshape(self):
        """Lower a module calling ``aten.view`` and compare it with the traced module.

        First checks, via ``lower_graph_testing``, that ``aten.reshape`` shows
        up and ``aten.view`` does not. Then compiles the traced module with
        ``torch_tensorrt.compile`` and compares its output against the traced
        module's output. The input is moved to the GPU with ``.cuda()``, so a
        CUDA device is required.
        """

        class ViewToReshape(torch.nn.Module):
            def forward(self, input):
                out = torch.ops.aten.view.default(input, (1, 1, -1))
                return out

        inputs = [
            torch.rand((3, 4, 5, 32)).cuda(),
        ]

        fx_graph = torch.fx.symbolic_trace(ViewToReshape())
        expected_ops = {torch.ops.aten.reshape.default}
        unexpected_ops = {
            torch.ops.aten.view.default,
        }

        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
            fx_graph,
            inputs,
            expected_ops=expected_ops,
            unexpected_ops=unexpected_ops,
            min_block_size=1,
        )

        # assertEqual rather than the assertEquals alias, which was removed
        # from unittest in Python 3.12.
        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            msg=f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )

        self.assertEqual(
            len(expected_ops_unseen),
            0,
            msg=f"The following expected ops were not encountered: {expected_ops_unseen}",
        )
        torch._dynamo.reset()

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torch_tensorrt.compile(
            fx_graph,
            "torch_compile",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
        )
        optimized_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in optimized_model(*inputs)]
        )
        torch_model_results = torch.cat(
            [tensor.detach().cpu() for tensor in fx_graph(*inputs)]
        )

        max_diff = float(
            torch.max(torch.abs(optimized_model_results - torch_model_results))
        )
        self.assertAlmostEqual(
            max_diff,
            0,
            DECIMALS_OF_AGREEMENT,
            msg="ViewToReshape TRT outputs don't match with the original model.",
        )
        torch._dynamo.reset()
442+
443+
378444
if __name__ == "__main__":
379445
run_tests()

0 commit comments

Comments
 (0)