
Commit aa0dda8

fix: Move key functions, fix bugs
- Improve overall functionality, fix bugs
- Move functions into __init__.py
- Improve overall documentation, comments, function header typing, and code organization
1 parent eea3884 commit aa0dda8

4 files changed: +146 additions, -137 deletions

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 119 additions & 2 deletions
@@ -1,2 +1,119 @@
-from _compiler import compile
-from backends import create_backend
+import torch
+import logging
+import torch_tensorrt
+from functools import partial
+
+from typing import Sequence, Any
+from torch_tensorrt import EngineCapability, Device
+from torch_tensorrt.fx.utils import LowerPrecision
+
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.backends import tensorrt_backend
+from torch_tensorrt.dynamo._defaults import (
+    PRECISION,
+    DEBUG,
+    MAX_WORKSPACE_SIZE,
+    MAX_NUM_TRT_ENGINES,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def compile(
+    gm: torch.nn.Module,
+    example_inputs: Sequence[Any],
+    *,
+    device=Device._current_device(),
+    disable_tf32=False,
+    sparse_weights=False,
+    enabled_precisions=set(),
+    refit=False,
+    debug=DEBUG,
+    capability=EngineCapability.default,
+    num_avg_timing_iters=1,
+    workspace_size=MAX_WORKSPACE_SIZE,
+    dla_sram_size=1048576,
+    dla_local_dram_size=1073741824,
+    dla_global_dram_size=536870912,
+    calibrator=None,
+    truncate_long_and_double=False,
+    require_full_compilation=False,
+    min_block_size=3,
+    torch_executed_ops=[],
+    torch_executed_modules=[],
+    **kwargs,
+):
+
+    logger.warn(
+        "The Dynamo backend is an experimental feature, for which only the "
+        + "following arguments are supported: "
+        + "{enabled_precisions, debug, workspace_size, max_num_trt_engines}"
+    )
+
+    if (
+        torch.float16 in enabled_precisions
+        or torch_tensorrt.dtype.half in enabled_precisions
+    ):
+        lower_precision = LowerPrecision.FP16
+    elif (
+        torch.float32 in enabled_precisions
+        or torch_tensorrt.dtype.float in enabled_precisions
+    ):
+        lower_precision = LowerPrecision.FP32
+    elif len(enabled_precisions) == 0:
+        logger.info(f"No precision specified, defaulting to {PRECISION}")
+        lower_precision = PRECISION
+    else:
+        raise ValueError(
+            f"Precision {enabled_precisions} not supported in the Dynamo Path"
+        )
+
+    custom_backend = create_backend(
+        precision=lower_precision,
+        debug=debug,
+        workspace_size=workspace_size,
+        **kwargs,
+    )
+
+    model = torch.compile(gm, backend=custom_backend)
+
+    # Ensure compilation occurs by calling the function with provided inputs
+    model(*example_inputs)
+
+    return model
+
+
+from torch_tensorrt.fx.utils import LowerPrecision
+
+logger = logging.getLogger(__name__)
+
+
+def create_backend(
+    precision: LowerPrecision = PRECISION,
+    debug: bool = DEBUG,
+    workspace_size: int = MAX_WORKSPACE_SIZE,
+    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
+    **kwargs,
+):
+    """Create torch.compile backend given specified arguments
+
+    Args:
+        precision:
+        debug: Whether to print out verbose debugging information
+        workspace_size: Maximum workspace TRT is allowed to use for the module
+        precision: Model Layer precision
+    Returns:
+        Backend for torch.compile
+    """
+    settings = CompilationSettings(
+        debug=debug,
+        precision=precision,
+        workspace_size=workspace_size,
+        max_num_trt_engines=max_num_trt_engines,
+    )
+
+    return partial(
+        tensorrt_backend,
+        settings=settings,
+    )
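
For context, a minimal usage sketch of the relocated compile entry point defined above; the model, input shape, and workspace value are hypothetical, and only arguments the warning message lists as supported are exercised:

import torch
import torch_tensorrt.dynamo

# Hypothetical toy model and input -- any traceable nn.Module on a CUDA device would do
model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# compile() wraps the module with the TensorRT backend and runs it once on the
# provided inputs to trigger engine construction
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs,
    enabled_precisions={torch.half},  # maps to LowerPrecision.FP16 above
    debug=True,
    workspace_size=1 << 30,  # hypothetical 1 GiB workspace
)

out = trt_model(*inputs)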

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

py/torch_tensorrt/dynamo/backends.py

Lines changed: 14 additions & 47 deletions
@@ -1,15 +1,9 @@
+from typing import Sequence
 import torch
-import logging
 import traceback
 from functools import partial
 import torch._dynamo as td

-from torch_tensorrt.dynamo._defaults import (
-    PRECISION,
-    DEBUG,
-    MAX_WORKSPACE_SIZE,
-    MAX_NUM_TRT_ENGINES,
-)
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
 from torch_tensorrt.dynamo.lowering._partition import partition, get_submod_inputs
@@ -19,49 +13,14 @@

 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler

-from torch_tensorrt.fx.utils import LowerPrecision
-
-logger = logging.getLogger(__name__)
-
-
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = MAX_WORKSPACE_SIZE,
-    max_num_trt_engines: int = MAX_NUM_TRT_ENGINES,
-    **kwargs
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision:
-        debug: Whether to print out verbose debugging information
-        workspace_size: Maximum workspace TRT is allowed to use for the module
-        precision: Model Layer precision
-    Returns:
-        Backend for torch.compile
-    """
-    settings = CompilationSettings(
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        max_num_trt_engines=max_num_trt_engines,
-    )
-
-    return partial(
-        tensorrt_backend,
-        settings=settings,
-    )
-

 @td.register_backend(name="tensorrt")
 @fake_tensor_unsupported
 def tensorrt_backend(
-    gm: torch.Module,
-    sample_inputs,
+    gm: torch.nn.Module,
+    sample_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
 ):
-
     custom_backend = partial(
         fx_dynamo_backend,
         settings=settings,
@@ -80,10 +39,18 @@ def tensorrt_backend(
 @fake_tensor_unsupported
 def fx_dynamo_backend(
     gm: torch.fx.GraphModule,
-    example_inputs,
+    example_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
 ):
-    """Helper function to manage translation of FX module to TRT engines"""
+    """Helper function to manage translation of FX module to TRT engines
+
+    Args:
+        module: FX GraphModule to convert
+        inputs: Inputs to the module
+        settings: Compilation settings
+    Returns:
+        Compiled FX GraphModule
+    """
     try:
         trt_compiled = compile_module(
             gm,
@@ -102,7 +69,7 @@ def fx_dynamo_backend(

 def compile_module(
     gm: torch.fx.GraphModule,
-    example_inputs,
+    example_inputs: Sequence[torch.Tensor],
     settings: CompilationSettings = CompilationSettings(),
 ) -> torch.fx.GraphModule:
     """Compile an FX module

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 13 additions & 7 deletions
@@ -38,15 +38,19 @@ def is_node_supported(

     def print_support_overview(self, num_trt_blocks: Optional[int] = None):
         if num_trt_blocks is not None:
-            print(f"Number of TensorRT-Accelerated Subgraphs: {num_trt_blocks}\n")
+            print(f"\nNumber of TensorRT-Accelerated Subgraphs: {num_trt_blocks}")

-        print("Supported Nodes:")
+        print("\nSupported Nodes:")
         for node_name in self.supported_operators:
-            print(node_name)
+            print("-", node_name)

-        print("\nUnsupported Nodes:")
-        for node_name in self.unsupported_operators:
-            print(node_name)
+        if len(self.unsupported_operators) != 0:
+            print("\nUnsupported Nodes:")
+            for node_name in self.unsupported_operators:
+                print("-", node_name)
+            print("\n")
+        else:
+            print("\nAll Nodes Supported\n")


 def partition(
@@ -88,7 +92,9 @@ def partition(


 def get_submod_inputs(
-    mod: torch.fx.GraphModule, submod: torch.fx.GraphModule, inputs
+    mod: torch.fx.GraphModule,
+    submod: torch.fx.GraphModule,
+    inputs: Sequence[torch.Tensor],
 ) -> Sequence[torch.Tensor]:
     """Helper function to get inputs to a Torch submodule
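
For illustration only, with hypothetical operator names, the reworked print_support_overview now emits output along these lines (each node is prefixed with a dash, and an "All Nodes Supported" message is printed instead when the unsupported list is empty):

Number of TensorRT-Accelerated Subgraphs: 2

Supported Nodes:
- aten.add.Tensor
- aten.relu.default

Unsupported Nodes:
- aten.nonzero.default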
