@@ -1,22 +1,19 @@
 from __future__ import annotations

 import logging
-from functools import partial
+import unittest
 from typing import Any, Callable, Sequence

 import torch
 import torch._dynamo as td
-from torch._functorch.aot_autograd import make_boxed_compiler
-from torch._guards import TracingContext
+from torch._dynamo.utils import detect_fake_mode
+from torch._functorch.aot_autograd import aot_export_joint_simple
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
-from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

-from .aot_module import aot_module
-
 logger = logging.getLogger(__name__)


@@ -37,8 +34,6 @@ def torch_tensorrt_backend(

     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

-    TracingContext.get().fake_mode.allow_non_fake_inputs = True
-
     return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


@@ -48,21 +43,26 @@ def aot_torch_tensorrt_aten_backend(
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)

-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
     # Perform Pre-AOT Lowering for Module-Level Replacement
     gm = pre_aot_substitutions(gm)

-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(settings.enable_experimental_decompositions),
-    )
+    fake_mode = detect_fake_mode(sample_inputs)
+
+    # Place backend tracing within FakeTensor context allowing nonfake Tensors
+    with unittest.mock.patch.object(
+        fake_mode, "allow_non_fake_inputs", True
+    ), fake_mode:
+        # Invoke AOTAutograd to translate operators to aten
+        graph_module = aot_export_joint_simple(
+            gm,
+            sample_inputs,
+            trace_joint=False,
+            decompositions=get_decompositions(
+                settings.enable_experimental_decompositions
+            ),
+        )
+
+        return _pretraced_backend(graph_module, sample_inputs, settings)


 def _pretraced_backend(
@@ -82,16 +82,9 @@ def _pretraced_backend(
     try:
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-        frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
-        nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
-
-        frozen_gm.graph.eliminate_dead_code()
-        frozen_gm.graph.lint()
-        frozen_gm.recompile()
-
         trt_compiled = compile_module(
-            frozen_gm,
-            nonfrozen_inputs,
+            gm,
+            sample_inputs,
             settings=settings,
         )
         return trt_compiled
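As a standalone illustration of the new tracing flow, the sketch below exercises the same pattern outside the backend: build fake sample inputs, recover their FakeTensorMode with detect_fake_mode, scope the allow_non_fake_inputs override to the tracing region (rather than mutating the global TracingContext, as the removed code did), and export a forward-only aten graph via aot_export_joint_simple. This is a minimal sketch, not part of the PR: the toy function fn and the input shape are hypothetical stand-ins for the dynamo-produced GraphModule and its sample inputs.

import unittest.mock

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch._subclasses.fake_tensor import FakeTensorMode


def fn(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical parameter-free computation standing in for the
    # dynamo-produced GraphModule `gm`
    return torch.relu(x) + 1


# Create fake sample inputs, approximating what dynamo hands the backend
with FakeTensorMode() as fake_mode:
    sample_inputs = [torch.empty(4, 8)]

# detect_fake_mode recovers the mode the inputs were created under
assert detect_fake_mode(sample_inputs) is fake_mode

# Patch allow_non_fake_inputs only for the duration of tracing, then trace
# under the fake mode, mirroring aot_torch_tensorrt_aten_backend above
with unittest.mock.patch.object(
    fake_mode, "allow_non_fake_inputs", True
), fake_mode:
    # Forward-only export (trace_joint=False) lowers the callable to aten IR
    graph_module = aot_export_joint_simple(fn, sample_inputs, trace_joint=False)

print(graph_module.graph)

Because unittest.mock.patch.object restores the previous attribute value on exit, the fake mode is left untouched for any later compilation in the same process, which is the advantage over the global-flag approach this diff removes.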