Commit 1db0582

refactor: Promote lowering to top level, remove input_tensor_spec, rename fx2trt
Signed-off-by: Dheeraj Peri <[email protected]>

chore: apply linting

Signed-off-by: Dheeraj Peri <[email protected]>
1 parent a41c190 commit 1db0582

File tree

18 files changed: +197 -304 lines
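
Before the per-file diffs, here is a rough before/after summary of the import moves this commit implies for downstream code. The module paths are taken from the hunks below; the grouping into a single call site is illustrative only, not part of the commit:

# Before this commit (paths removed in the diffs below):
# from torch_tensorrt.dynamo.backend.lowering import fuse_permute_linear, fuse_permute_matmul
# from torch_tensorrt.dynamo.conversion import InputTensorSpec, TRTInterpreter
# import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer

# After this commit (paths added in the diffs below):
from torch_tensorrt.dynamo.lowering import fuse_permute_linear, fuse_permute_matmul
from torch_tensorrt.dynamo.conversion import TRTInterpreter  # fx2trt renamed to trt_interpreter
from torch_tensorrt import Input                             # replaces InputTensorSpec
from torch_tensorrt.dynamo.aten_tracer import trace          # tracing/lowering promoted to top level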
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import copy
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union

from packaging import version

import torch
import torch._dynamo as torchdynamo
from torch.fx.passes.infra.pass_base import PassResult

from torch_tensorrt.fx.utils import req_torch_version
from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
    compose_bmm,
    compose_chunk,
    compose_getitem_slice,
    remove_ops,
    replace_aten_op_with_indices,
    replace_aten_reshape_alias_with_replace,
    replace_builtin_ops,
    replace_inplace_ops,
    replace_native_layernorm_with_layernorm,
    replace_transpose_mm_op_with_linear,
    run_const_fold,
)
from typing_extensions import TypeAlias

Value: TypeAlias = Union[
    Tuple["Value", ...],
    List["Value"],
    Dict[str, "Value"],
]


class DynamoConfig:
    """
    Manage Exir-specific configurations of Dynamo.
    """

    def __init__(
        self,
        capture_scalar_outputs: bool = True,
        guard_nn_modules: bool = True,
        dynamic_shapes: bool = True,
        specialize_int: bool = True,
        verbose: bool = True,
    ) -> None:
        self.capture_scalar_outputs = capture_scalar_outputs
        self.guard_nn_modules = guard_nn_modules
        self.dynamic_shapes = dynamic_shapes
        self.specialize_int = specialize_int
        self.verbose = verbose

    def activate(self) -> None:
        torchdynamo.config.capture_scalar_outputs = self.capture_scalar_outputs
        torchdynamo.config.guard_nn_modules = self.guard_nn_modules
        torchdynamo.config.dynamic_shapes = self.dynamic_shapes
        torchdynamo.config.specialize_int = self.specialize_int
        torchdynamo.config.verbose = self.verbose

    def deactivate(self) -> None:
        torchdynamo.config.capture_scalar_outputs = True
        torchdynamo.config.guard_nn_modules = True
        torchdynamo.config.dynamic_shapes = True
        torchdynamo.config.specialize_int = True
        torchdynamo.config.verbose = True


@contextmanager
def using_config(config: DynamoConfig) -> Generator[DynamoConfig, None, None]:
    config.activate()
    try:
        yield config
    finally:
        config.deactivate()


@contextmanager
def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None, None]:
    """
    Temporarily increase the Python interpreter stack recursion limit.
    This is mostly used for pickling large-scale modules.
    """
    default = sys.getrecursionlimit()
    if limit > default:
        sys.setrecursionlimit(limit)
    try:
        yield
    finally:
        sys.setrecursionlimit(default)


@req_torch_version("2.dev")
def dynamo_trace(
    f: Callable[..., Value],
    # pyre-ignore
    args: Tuple[Any, ...],
    aten_graph: bool,
    tracing_mode: str = "real",
    dynamo_config: Optional[DynamoConfig] = None,
) -> Tuple[torch.fx.GraphModule, Set]:
    """
    TODO: Once we fully migrate to the torchdynamo frontend, we will remove
    this config option altogether. For now, it helps with quick
    experiments with TorchDynamo.
    """
    if dynamo_config is None:
        dynamo_config = DynamoConfig()
    with using_config(dynamo_config), setting_python_recursive_limit(2000):
        torchdynamo.reset()
        try:
            return torchdynamo.export(
                f,
                *copy.deepcopy(args),
                aten_graph=aten_graph,
                tracing_mode=tracing_mode,
            )
        except torchdynamo.exc.Unsupported as exc:
            raise RuntimeError(
                "The user code is using a feature we don't support. "
                "Please try torchdynamo.explain() to get the possible reasons",
            ) from exc
        except Exception as exc:
            raise RuntimeError(
                "torchdynamo internal error occurred. Please see the stacktrace above"
            ) from exc


@req_torch_version("2.dev")
def trace(model, inputs, **kwargs):
    """
    Optimized trace with necessary passes which re-compose or replace some ops.
    These passes should be general and functional-purpose.
    """
    passes_list = [
        compose_bmm,
        compose_chunk,
        compose_getitem_slice,
        replace_aten_reshape_alias_with_replace,
        replace_aten_op_with_indices,
        replace_transpose_mm_op_with_linear,  # after compose_bmm
        replace_native_layernorm_with_layernorm,
        remove_ops,
        replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
        replace_inplace_ops,  # remove it once functionalization is enabled
    ]

    # torchdynamo.export returns (graph_module, guards); the guards are unused here.
    fx_module, _ = dynamo_trace(model, inputs, True, "symbolic")
    print(fx_module.graph)
    for passes in passes_list:
        pr: PassResult = passes(fx_module)
        fx_module = pr.graph_module

    fx_module(*inputs)

    fx_module = run_const_fold(fx_module)
    print(fx_module.graph)
    return fx_module

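A minimal usage sketch of the new tracer. The module path torch_tensorrt.dynamo.aten_tracer is inferred from the compile.py hunk below, and the toy module is hypothetical; torch 2.x is required per the req_torch_version decorator:

import torch
from torch_tensorrt.dynamo.aten_tracer import trace

class Add(torch.nn.Module):  # hypothetical example module
    def forward(self, x, y):
        return x + y

inputs = [torch.randn(2, 3), torch.randn(2, 3)]
# Returns a torch.fx.GraphModule traced to the aten opset, with the
# re-composition/replacement passes listed above already applied.
fx_mod = trace(Add(), inputs)
print(fx_mod.graph)
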
py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 3 additions & 3 deletions
@@ -5,13 +5,13 @@
 import torch._dynamo as td

 from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from torch_tensorrt.dynamo.lowering._decompositions import (
     get_decompositions,
 )
-from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+from torch_tensorrt.dynamo.lowering._pre_aot_lowering import (
     pre_aot_substitutions,
 )
-from torch_tensorrt.dynamo.backend.lowering._partition import (
+from torch_tensorrt.dynamo.lowering._partition import (
     partition,
     get_submod_inputs,
 )

py/torch_tensorrt/dynamo/compile.py

Lines changed: 3 additions & 31 deletions
@@ -9,9 +9,9 @@
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
-import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
+from torch_tensorrt.dynamo.aten_tracer import trace
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
-from torch_tensorrt.dynamo.backend.lowering import (
+from torch_tensorrt.dynamo.lowering import (
     fuse_permute_linear,
     fuse_permute_matmul,
 )
@@ -113,8 +113,7 @@ def compile(
     model = trace(gm, inputs, **kwargs)

     if kwargs.get("use_capability_partitioner", None):
-        traced_model = trace(model)
-        model = lower_model(traced_model, inputs)
+        model = lower_model(model, inputs)
         return _compile_module(model, inputs, settings)
     else:
         split_result = lower_model_using_trt_splitter(model, inputs)
@@ -146,33 +145,6 @@ def _compile_graph(
     return split_result.split_module


-def trace(
-    model: torch.nn.Module,
-    inputs: Any,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
-    Returns:
-        Backend for torch.compile
-    """
-    model = aten_tracer.opt_trace(model, inputs)
-
-    return model
-
-
 def lower_model_using_trt_splitter(model: torch.nn.Module, inputs: Any, **kwargs):
     # Perform basic lowering
     model = lower_model(model, inputs)
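
With the duplicated trace() removed, compile() traces the module once and, when the capability partitioner is requested, lowers that traced module directly. A hypothetical invocation follows; the exact public signature of compile and whether it is re-exported as torch_tensorrt.dynamo.compile are assumptions, only the use_capability_partitioner kwarg is visible in this diff:

import torch
from torch_tensorrt.dynamo.compile import compile  # assumed import location

model = MyModel().eval().cuda()                    # hypothetical user module
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# The already-traced module is lowered in place rather than traced a second time.
trt_model = compile(model, inputs, use_capability_partitioner=True)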
Lines changed: 1 addition & 2 deletions
@@ -1,3 +1,2 @@
-from .input_tensor_spec import *
-from .fx2trt import *
+from .trt_interpreter import *
 from .conversion import *

py/torch_tensorrt/dynamo/conversion/conversion.py

Lines changed: 4 additions & 7 deletions
@@ -3,10 +3,9 @@
 import io
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt.dynamo import CompilationSettings
-from torch_tensorrt.dynamo.conversion import (
-    InputTensorSpec,
-    TRTInterpreter,
-)
+from torch_tensorrt import Input
+from torch_tensorrt.dynamo.conversion import TRTInterpreter
+

 import tensorrt as trt

@@ -34,14 +33,12 @@ def convert_module(
         module_outputs = [module_outputs]

     output_dtypes = list(output.dtype for output in module_outputs)
-
     interpreter = TRTInterpreter(
         module,
-        InputTensorSpec.from_tensors(inputs),
+        Input.from_tensors(inputs),
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
     )
-
     interpreter_result = interpreter.run(
         workspace_size=settings.workspace_size,
         lower_precision=settings.precision,

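With InputTensorSpec removed, the conversion path describes engine inputs through torch_tensorrt.Input. A small sketch of the substitution used above, assuming Input.from_tensors keeps its behaviour of deriving one Input spec per example tensor:

import torch
from torch_tensorrt import Input

example_tensors = [torch.randn(8, 3, 224, 224, device="cuda")]
# Shape/dtype specs derived from the example tensors, as handed to TRTInterpreter above.
input_specs = Input.from_tensors(example_tensors)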