@@ -1,7 +1,8 @@
 import logging
 from typing import Sequence
 import torch
-from functools import partial
+from torch._dynamo.utils import detect_fake_mode
+import unittest
 import torch._dynamo as td
 from torch._guards import TracingContext
@@ -16,15 +17,13 @@
     partition,
     get_submod_inputs,
 )
-from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
 from torch_tensorrt.dynamo.conversion import (
     convert_module,
     repair_long_or_double_inputs,
 )

-from torch._functorch.aot_autograd import make_boxed_compiler
-from .aot_module import aot_module
+from torch._functorch.aot_autograd import aot_export_joint_simple


 logger = logging.getLogger(__name__)
@@ -36,8 +35,6 @@ def torch_tensorrt_backend(
 ):
     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

-    TracingContext.get().fake_mode.allow_non_fake_inputs = True
-
     return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


@@ -47,21 +44,25 @@ def aot_torch_tensorrt_aten_backend(
 ):
     settings = parse_dynamo_kwargs(kwargs)

-    custom_backend = partial(
-        _pretraced_backend,
-        settings=settings,
-    )
-
     # Perform Pre-AOT Lowering for Module-Level Replacement
     gm = pre_aot_substitutions(gm)

-    # Invoke AOTAutograd to translate operators to aten
-    return aot_module(
-        gm,
-        sample_inputs,
-        fw_compiler=make_boxed_compiler(custom_backend),
-        decompositions=get_decompositions(),
-    )
+    fake_mode = detect_fake_mode(sample_inputs)
+
+    # Place backend tracing within FakeTensor context allowing nonfake Tensors
+    with unittest.mock.patch.object(
+        fake_mode, "allow_non_fake_inputs", True
+    ), fake_mode:
+
+        # Invoke AOTAutograd to translate operators to aten
+        graph_module = aot_export_joint_simple(
+            gm,
+            sample_inputs,
+            trace_joint=False,
+            decompositions=get_decompositions(),
+        )
+
+        return _pretraced_backend(graph_module, sample_inputs, settings)


 def _pretraced_backend(
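For context (not part of the diff itself), here is a minimal sketch of the fake-mode pattern introduced above: `detect_fake_mode` recovers the FakeTensorMode that Dynamo traced with, the `unittest.mock.patch.object` context temporarily sets `allow_non_fake_inputs` so real tensors (e.g. parameters) can pass through re-tracing, and `aot_export_joint_simple(..., trace_joint=False)` lowers to an inference-only aten graph. The toy backend and model below are illustrative assumptions.

```python
# Illustrative sketch only -- the toy backend and model here are
# assumptions, not part of this diff.
import unittest.mock

import torch
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple


def toy_backend(gm: torch.fx.GraphModule, sample_inputs):
    # Recover the FakeTensorMode active during Dynamo tracing
    fake_mode = detect_fake_mode(sample_inputs)

    # Temporarily permit real (non-fake) tensors inside the fake tensor
    # context; the flag is restored automatically when the block exits
    with unittest.mock.patch.object(
        fake_mode, "allow_non_fake_inputs", True
    ), fake_mode:
        # Inference-only lowering to aten (no joint forward/backward graph)
        return aot_export_joint_simple(gm, sample_inputs, trace_joint=False)


fn = torch.compile(torch.nn.Linear(4, 4).eval(), backend=toy_backend)
fn(torch.randn(2, 4))  # triggers compilation through toy_backend
```

Scoping the patch in a context manager, rather than mutating `TracingContext.get().fake_mode` globally as the removed line in `torch_tensorrt_backend` did, ensures the flag is reset once tracing finishes.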
@@ -81,16 +82,9 @@ def _pretraced_backend(
     try:
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-        frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
-        nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
-
-        frozen_gm.graph.eliminate_dead_code()
-        frozen_gm.graph.lint()
-        frozen_gm.recompile()
-
         trt_compiled = _compile_module(
-            frozen_gm,
-            nonfrozen_inputs,
+            gm,
+            sample_inputs,
             settings=settings,
         )
         return trt_compiled
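Usage downstream of this change is unchanged: the backend is still invoked through `torch.compile`. A hypothetical end-to-end sketch follows; the model, input shape, and options are illustrative assumptions, with `min_block_size` standing in for any setting consumed via `parse_dynamo_kwargs`.

```python
# Hypothetical usage sketch -- model, input shape, and options are
# illustrative assumptions, not taken from this diff.
import torch
import torch_tensorrt  # noqa: F401  (registers the "torch_tensorrt" backend)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval().cuda()
inputs = (torch.randn(1, 3, 224, 224, device="cuda"),)

# `options` entries are forwarded to the backend as kwargs and parsed by
# parse_dynamo_kwargs() into compilation settings
optimized = torch.compile(model, backend="torch_tensorrt", options={"min_block_size": 2})
print(optimized(*inputs).shape)
```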