Skip to content

Commit c96c00d

Browse files
committed
fix: Move tracer code into try/except
1 parent 3efcea0 commit c96c00d

File tree

1 file changed

+32
-31
lines changed

1 file changed

+32
-31
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,7 @@ def aot_torch_tensorrt_aten_backend(
4343
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
4444
) -> torch.nn.Module:
4545
settings = parse_dynamo_kwargs(kwargs)
46-
47-
# Perform Pre-AOT Lowering for Module-Level Replacement
48-
gm = pre_aot_substitutions(gm)
49-
50-
fake_mode = detect_fake_mode(sample_inputs)
51-
52-
# Place backend tracing within FakeTensor context allowing nonfake Tensors
53-
with unittest.mock.patch.object(
54-
fake_mode, "allow_non_fake_inputs", True
55-
), fake_mode:
56-
# Invoke AOTAutograd to translate operators to aten
57-
graph_module = aot_export_joint_simple(
58-
gm,
59-
sample_inputs,
60-
trace_joint=False,
61-
decompositions=get_decompositions(
62-
settings.enable_experimental_decompositions
63-
),
64-
)
65-
66-
constant_fold(graph_module)
67-
68-
return _pretraced_backend(graph_module, sample_inputs, settings)
46+
return _pretraced_backend(gm, sample_inputs, settings)
6947

7048

7149
def _pretraced_backend(
@@ -83,15 +61,38 @@ def _pretraced_backend(
8361
Compiled FX GraphModule
8462
"""
8563
try:
86-
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
64+
logger.debug("Pre-AOT Autograd graph:\n" + str(gm.graph))
65+
66+
# Perform Pre-AOT Lowering for Module-Level Replacement
67+
gm = pre_aot_substitutions(gm)
68+
69+
fake_mode = detect_fake_mode(sample_inputs)
70+
71+
# Place backend tracing within FakeTensor context allowing nonfake Tensors
72+
with unittest.mock.patch.object(
73+
fake_mode, "allow_non_fake_inputs", True
74+
), fake_mode:
75+
# Invoke AOTAutograd to translate operators to aten
76+
graph_module = aot_export_joint_simple(
77+
gm,
78+
sample_inputs,
79+
trace_joint=False,
80+
decompositions=get_decompositions(
81+
settings.enable_experimental_decompositions
82+
),
83+
)
8784

88-
trt_compiled = compile_module(
89-
gm,
90-
sample_inputs,
91-
settings=settings,
92-
)
93-
return trt_compiled
94-
except AssertionError:
85+
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
86+
87+
constant_fold(graph_module)
88+
89+
trt_compiled = compile_module(
90+
graph_module,
91+
sample_inputs,
92+
settings=settings,
93+
)
94+
return trt_compiled
95+
except (AssertionError, RuntimeError):
9596
if not settings.pass_through_build_failures:
9697
logger.warning(
9798
"TRT conversion failed on the subgraph. See trace above. "

0 commit comments

Comments
 (0)