
Commit e6da7e4

chore: some reorg and internal cleanup, addressing review comments

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent: ae453fc

File tree: 11 files changed, +187 −195 lines

py/torch_tensorrt/__init__.py (3 additions, 3 deletions)

@@ -92,11 +92,11 @@ def _find_lib(name: str, paths: List[str]) -> str:
 def _register_with_torch() -> None:
     trtorch_dir = os.path.dirname(__file__)
     if os.path.isfile(trtorch_dir + "/lib/libtorchtrt.so"):
-        assert ENABLED_FEATURES.torchscript_frontend == True
-        assert ENABLED_FEATURES.torch_tensorrt_runtime == True
+        assert ENABLED_FEATURES.torchscript_frontend
+        assert ENABLED_FEATURES.torch_tensorrt_runtime
         torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt.so")
     elif os.path.isfile(trtorch_dir + "/lib/libtorchtrt_runtime.so"):
-        assert ENABLED_FEATURES.torch_tensorrt_runtime == True
+        assert ENABLED_FEATURES.torch_tensorrt_runtime
         torch.ops.load_library(trtorch_dir + "/lib/libtorchtrt_runtime.so")

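The change above is lint-level cleanup: comparing a boolean to True with == is the pattern flake8 flags as E712, and a bare assert performs the same truthiness check directly. A minimal illustration (the variable name is invented for the example):

    # Bare asserts test truthiness directly; `== True` adds no safety and is
    # flagged by flake8 as E712.
    feature_enabled = True
    assert feature_enabled                   # idiomatic
    assert feature_enabled == True  # noqa: E712, equivalent but non-idiomatic
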
py/torch_tensorrt/_compile.py (36 additions, 17 deletions)

@@ -9,6 +9,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
@@ -17,6 +18,9 @@
 if ENABLED_FEATURES.torchscript_frontend:
     import torch_tensorrt.ts
     from torch_tensorrt.ts._compiler import compile as torchscript_compile
+    from torch_tensorrt.ts._compiler import (
+        convert_method_to_trt_engine as ts_convert_method_to_trt_engine,
+    )

 if ENABLED_FEATURES.dynamo_frontend:
     from torch._export import ExportedProgram
@@ -88,20 +92,34 @@ def _get_target_fe(module_type: _ModuleType, ir: str) -> _IRType:
     ir_targets_dynamo = ir == "dynamo"
     ir_targets_torch_compile = ir == "torch_compile"

-    if (
-        module_is_tsable and ir_targets_torchscript
-    ) and ENABLED_FEATURES.torchscript_frontend:
-        return _IRType.ts
-    elif (module_is_fxable and ir_targets_fx) and ENABLED_FEATURES.fx_frontend:
-        return _IRType.fx
-    elif (
-        (module_is_fxable or module_is_exportable) and ir_targets_dynamo
-    ) and ENABLED_FEATURES.dynamo_frontend:
-        return _IRType.dynamo
-    elif (
-        module_is_fxable and ir_targets_torch_compile
-    ) and ENABLED_FEATURES.dynamo_frontend:
-        return _IRType.torch_compile
+    if module_is_tsable and ir_targets_torchscript:
+        if ENABLED_FEATURES.torchscript_frontend:
+            return _IRType.ts
+        else:
+            raise ValueError(
+                "Requested using the TS frontend but the TS frontend is not available in this build of Torch-TensorRT"
+            )
+    elif module_is_fxable and ir_targets_fx:
+        if ENABLED_FEATURES.fx_frontend:
+            return _IRType.fx
+        else:
+            raise ValueError(
+                "Requested using the FX frontend but the FX frontend is not available in this build of Torch-TensorRT"
+            )
+    elif (module_is_fxable or module_is_exportable) and ir_targets_dynamo:
+        if ENABLED_FEATURES.dynamo_frontend:
+            return _IRType.dynamo
+        else:
+            raise ValueError(
+                "Requested using the Dynamo frontend but the Dynamo frontend is not available in this build of Torch-TensorRT"
+            )
+    elif module_is_fxable and ir_targets_torch_compile:
+        if ENABLED_FEATURES.dynamo_frontend:
+            return _IRType.torch_compile
+        else:
+            raise ValueError(
+                "Requested using the Torch-TensorRT torch.compile backend but the Torch-TensorRT torch.compile backend is not available in this build of Torch-TensorRT"
+            )
     else:
         if ir == "default":
             # Options are listed in order of preference
@@ -169,9 +187,9 @@ def compile(
     Returns:
         torch.nn.Module: Compiled Module, when run it will execute via TensorRT
     """
-    input_list = inputs if inputs is not None else []
+    input_list = inputs if inputs else []
     enabled_precisions_set: Set[dtype | torch.dtype] = (
-        enabled_precisions if enabled_precisions is not None else {dtype.float}
+        enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
     )

     module_type = _parse_module_type(module)
@@ -309,13 +327,14 @@ def convert_method_to_trt_engine(
             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
         )
         ts_mod = torch.jit.script(module)
-        return torch_tensorrt.ts.convert_method_to_trt_engine(
+        serialized_engine: bytes = ts_convert_method_to_trt_engine(
             ts_mod,
             inputs=inputs,
             method_name=method_name,
             enabled_precisions=enabled_precisions_set,
             **kwargs,
         )
+        return serialized_engine
     elif target_ir == _IRType.fx:
         raise RuntimeError(
             "convert_method_to_trt_engine call is not supported for ir=fx"

py/torch_tensorrt/dynamo/_compiler.py (53 additions, 102 deletions)

@@ -10,34 +10,7 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.dynamo._defaults import (
-    DEBUG,
-    DEVICE,
-    DISABLE_TF32,
-    DLA_GLOBAL_DRAM_SIZE,
-    DLA_LOCAL_DRAM_SIZE,
-    DLA_SRAM_SIZE,
-    DRYRUN,
-    ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
-    ENGINE_CAPABILITY,
-    HARDWARE_COMPATIBLE,
-    MAX_AUX_STREAMS,
-    MIN_BLOCK_SIZE,
-    NUM_AVG_TIMING_ITERS,
-    OPTIMIZATION_LEVEL,
-    OUTPUT_FORMAT,
-    PASS_THROUGH_BUILD_FAILURES,
-    PRECISION,
-    REFIT,
-    REQUIRE_FULL_COMPILATION,
-    SPARSE_WEIGHTS,
-    TRUNCATE_LONG_AND_DOUBLE,
-    USE_FAST_PARTITIONER,
-    USE_PYTHON_RUNTIME,
-    VERSION_COMPATIBLE,
-    WORKSPACE_SIZE,
-)
+from torch_tensorrt.dynamo import _defaults, partitioning
 from torch_tensorrt.dynamo._DryRunTracker import (
     DryRunTracker,
     PerSubgraphData,
@@ -72,35 +45,35 @@ def compile(
     exported_program: ExportedProgram,
     inputs: Tuple[Any, ...],
     *,
-    device: Optional[Union[Device, torch.device, str]] = DEVICE,
-    disable_tf32: bool = DISABLE_TF32,
-    sparse_weights: bool = SPARSE_WEIGHTS,
-    enabled_precisions: Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] = (
-        dtype.float32,
-    ),
-    engine_capability: EngineCapability = ENGINE_CAPABILITY,
-    refit: bool = REFIT,
-    debug: bool = DEBUG,
-    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
-    workspace_size: int = WORKSPACE_SIZE,
-    dla_sram_size: int = DLA_SRAM_SIZE,
-    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
-    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
-    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
-    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
-    min_block_size: int = MIN_BLOCK_SIZE,
+    device: Optional[Union[Device, torch.device, str]] = _defaults.DEVICE,
+    disable_tf32: bool = _defaults.DISABLE_TF32,
+    sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
+    enabled_precisions: (
+        Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
+    ) = _defaults.ENABLED_PRECISIONS,
+    engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
+    refit: bool = _defaults.REFIT,
+    debug: bool = _defaults.DEBUG,
+    num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
+    workspace_size: int = _defaults.WORKSPACE_SIZE,
+    dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
+    dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
+    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
+    min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Collection[Target]] = None,
     torch_executed_modules: Optional[List[str]] = None,
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: bool = USE_PYTHON_RUNTIME,
-    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
-    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
-    dryrun: bool = DRYRUN,
-    hardware_compatible: bool = HARDWARE_COMPATIBLE,
-    output_format: str = OUTPUT_FORMAT,
+    pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
+    version_compatible: bool = _defaults.VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
+    use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
+    use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    dryrun: bool = _defaults.DRYRUN,
+    hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
+    output_format: str = _defaults.OUTPUT_FORMAT,
     **kwargs: Any,
 ) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -182,6 +155,7 @@ def compile(
     # Prepare torch_trt inputs
     inputs = prepare_inputs(inputs)
     device = to_torch_tensorrt_device(device)
+    enabled_precisions = {dtype._from(p) for p in enabled_precisions}

     if not isinstance(exported_program, ExportedProgram):
         raise AssertionError(
@@ -198,21 +172,10 @@ def compile(
     gm = apply_lowering_passes(gm, torch_inputs)
     logger.debug("Lowered Input graph: " + str(gm.graph))

-    if dtype.float16 in enabled_precisions or dtype.half in enabled_precisions:
-        precision = dtype.float16
-    elif dtype.float32 in enabled_precisions or dtype.float in enabled_precisions:
-        precision = dtype.float32
-    elif len(enabled_precisions) == 0:
-        logger.info(f"No precision specified, defaulting to {PRECISION}")
-        precision = PRECISION
-    else:
-        raise ValueError(
-            f"Precision {enabled_precisions} not supported in the Dynamo Path"
-        )
-    enabled_precisions = {dtype._from(e) for e in enabled_precisions}
-
     compilation_options = {
-        "precision": precision,
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
+        ),
         "debug": debug,
         "device": device,
         "workspace_size": workspace_size,
@@ -459,28 +422,28 @@ def convert_module_to_trt_engine(
     enabled_precisions: Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype] = (
         dtype.float32,
     ),
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
+    debug: bool = _defaults.DEBUG,
+    workspace_size: int = _defaults.WORKSPACE_SIZE,
+    min_block_size: int = _defaults.MIN_BLOCK_SIZE,
     torch_executed_ops: Optional[Set[str]] = None,
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
-    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
-    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
+    version_compatible: bool = _defaults.VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = _defaults.USE_PYTHON_RUNTIME,
+    truncate_long_and_double: bool = _defaults.TRUNCATE_LONG_AND_DOUBLE,
+    use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     device: Device = Device._current_device(),
-    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
-    disable_tf32: bool = DISABLE_TF32,
-    sparse_weights: bool = SPARSE_WEIGHTS,
-    refit: bool = REFIT,
-    engine_capability: EngineCapability = ENGINE_CAPABILITY,
-    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
-    dla_sram_size: int = DLA_SRAM_SIZE,
-    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
-    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
+    require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
+    disable_tf32: bool = _defaults.DISABLE_TF32,
+    sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
+    refit: bool = _defaults.REFIT,
+    engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
+    num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
+    dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
+    dla_local_dram_size: int = _defaults.DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = _defaults.DLA_GLOBAL_DRAM_SIZE,
     calibrator: object = None,
     allow_shape_tensors: bool = False,
 ) -> bytes:
@@ -569,22 +532,10 @@ def convert_module_to_trt_engine(
     input_list = prepare_inputs(input_list)
     device = to_torch_tensorrt_device(device)

-    if dtype.float16 in enabled_precisions or dtype.half in enabled_precisions:
-        precision = dtype.float16
-    elif dtype.float32 in enabled_precisions or dtype.float in enabled_precisions:
-        precision = dtype.float32
-    elif len(enabled_precisions) == 0:
-        logger.info(f"No precision specified, defaulting to {PRECISION}")
-        precision = PRECISION
-    else:
-        raise ValueError(
-            f"Precision {enabled_precisions} not supported in the Dynamo Path"
-        )
-
     enabled_precisions = {dtype._from(e) for e in enabled_precisions}

     compilation_options = {
-        "precision": precision,
+        "enabled_precisions": enabled_precisions,
        "debug": debug,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,

py/torch_tensorrt/dynamo/_defaults.py (2 additions, 1 deletion)

@@ -2,7 +2,7 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import EngineCapability, dtype

-PRECISION = dtype.float32
+ENABLED_PRECISIONS = {dtype.f32}
 DEBUG = False
 DEVICE = None
 DISABLE_TF32 = False
@@ -27,6 +27,7 @@
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
 OUTPUT_FORMAT = "exported_program"
+SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}


 def default_device() -> Device:
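
`ENABLED_PRECISIONS` replaces the scalar `PRECISION` default with a one-element set, and `SUPPORTED_KERNEL_PRECISIONS` enumerates the dtypes the Dynamo path can build kernels for. This commit does not show where the latter is consumed, so the following validation helper is purely a hypothetical sketch of how the two constants could interact:

    from typing import Set, Union

    import torch
    from torch_tensorrt._enums import dtype
    from torch_tensorrt.dynamo import _defaults


    def validate_precisions(requested: Set[Union[torch.dtype, dtype]]) -> Set[dtype]:
        # Hypothetical helper, not part of this commit: normalize to
        # torch_tensorrt dtypes, fall back to the default set when empty,
        # and reject anything outside the supported kernel precisions.
        normalized = {dtype._from(p) for p in requested} or _defaults.ENABLED_PRECISIONS
        unsupported = normalized - _defaults.SUPPORTED_KERNEL_PRECISIONS
        if unsupported:
            raise ValueError(f"Unsupported kernel precisions: {unsupported}")
        return normalized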

py/torch_tensorrt/dynamo/_settings.py (3 additions, 3 deletions)

@@ -12,6 +12,7 @@
     DLA_SRAM_SIZE,
     DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    ENABLED_PRECISIONS,
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,
     MAX_AUX_STREAMS,
@@ -20,7 +21,6 @@
     OPTIMIZATION_LEVEL,
     OUTPUT_FORMAT,
     PASS_THROUGH_BUILD_FAILURES,
-    PRECISION,
     REFIT,
     REQUIRE_FULL_COMPILATION,
     SPARSE_WEIGHTS,
@@ -38,7 +38,7 @@ class CompilationSettings:
     """Compilation settings for Torch-TensorRT Dynamo Paths

     Args:
-        precision (torch.dtype): Model Layer precision
+        enabled_precisions (Set[dtype]): Available kernel dtype precisions
         debug (bool): Whether to print out verbose debugging information
         workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
         min_block_size (int): Minimum number of operators per TRT-Engine Block
@@ -73,7 +73,7 @@ class CompilationSettings:
         output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
     """

-    precision: dtype = field(default_factory=lambda: PRECISION)
+    enabled_precisions: dtype = field(default_factory=lambda: ENABLED_PRECISIONS)
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
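
With the dataclass field renamed, callers construct settings with a precision set rather than a single dtype. One caveat in the hunk above: the field annotation still reads `dtype` even though the default factory now returns a set, so `Set[dtype]` is presumably the intended annotation. A minimal usage sketch:

    from torch_tensorrt._enums import dtype
    from torch_tensorrt.dynamo._settings import CompilationSettings

    # Defaults to ENABLED_PRECISIONS ({dtype.f32}); any subset of the
    # supported kernel precisions can be requested.
    settings = CompilationSettings(
        enabled_precisions={dtype.f32, dtype.f16},
        debug=True,
    )
    print(settings.enabled_precisions)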
