@@ -43,29 +43,7 @@ def aot_torch_tensorrt_aten_backend(
43
43
gm : torch .fx .GraphModule , sample_inputs : Sequence [torch .Tensor ], ** kwargs : Any
44
44
) -> torch .nn .Module :
45
45
settings = parse_dynamo_kwargs (kwargs )
46
-
47
- # Perform Pre-AOT Lowering for Module-Level Replacement
48
- gm = pre_aot_substitutions (gm )
49
-
50
- fake_mode = detect_fake_mode (sample_inputs )
51
-
52
- # Place backend tracing within FakeTensor context allowing nonfake Tensors
53
- with unittest .mock .patch .object (
54
- fake_mode , "allow_non_fake_inputs" , True
55
- ), fake_mode :
56
- # Invoke AOTAutograd to translate operators to aten
57
- graph_module = aot_export_joint_simple (
58
- gm ,
59
- sample_inputs ,
60
- trace_joint = False ,
61
- decompositions = get_decompositions (
62
- settings .enable_experimental_decompositions
63
- ),
64
- )
65
-
66
- constant_fold (graph_module )
67
-
68
- return _pretraced_backend (graph_module , sample_inputs , settings )
46
+ return _pretraced_backend (gm , sample_inputs , settings )
69
47
70
48
71
49
def _pretraced_backend (
@@ -83,15 +61,38 @@ def _pretraced_backend(
83
61
Compiled FX GraphModule
84
62
"""
85
63
try :
86
- logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
64
+ logger .debug ("Pre-AOT Autograd graph:\n " + str (gm .graph ))
65
+
66
+ # Perform Pre-AOT Lowering for Module-Level Replacement
67
+ gm = pre_aot_substitutions (gm )
68
+
69
+ fake_mode = detect_fake_mode (sample_inputs )
70
+
71
+ # Place backend tracing within FakeTensor context allowing nonfake Tensors
72
+ with unittest .mock .patch .object (
73
+ fake_mode , "allow_non_fake_inputs" , True
74
+ ), fake_mode :
75
+ # Invoke AOTAutograd to translate operators to aten
76
+ graph_module = aot_export_joint_simple (
77
+ gm ,
78
+ sample_inputs ,
79
+ trace_joint = False ,
80
+ decompositions = get_decompositions (
81
+ settings .enable_experimental_decompositions
82
+ ),
83
+ )
87
84
88
- trt_compiled = compile_module (
89
- gm ,
90
- sample_inputs ,
91
- settings = settings ,
92
- )
93
- return trt_compiled
94
- except AssertionError :
85
+ logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
86
+
87
+ constant_fold (graph_module )
88
+
89
+ trt_compiled = compile_module (
90
+ graph_module ,
91
+ sample_inputs ,
92
+ settings = settings ,
93
+ )
94
+ return trt_compiled
95
+ except (AssertionError , RuntimeError ):
95
96
if not settings .pass_through_build_failures :
96
97
logger .warning (
97
98
"TRT conversion failed on the subgraph. See trace above. "
0 commit comments