Commit 010e9ae

rebase
1 parent 06acdef commit 010e9ae

3 files changed (+145, −21 lines)

py/torch_tensorrt/dynamo/_compiler.py (116 additions, 18 deletions)
@@ -50,7 +50,7 @@
     CompilationSettings,
     UnsupportedOperatorException,
     convert_module,
-    interpret_module,
+    interpret_module_to_result,
     repair_long_or_double_inputs,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -452,25 +452,108 @@ def convert_method_to_trt_engine(
     module: torch.fx.GraphModule,
     method_name: str = "forward",
     inputs: Optional[Sequence[Input | torch.Tensor]] = None,
-    device: Device = Device._current_device(),
-    disable_tf32: bool = False,
-    sparse_weights: bool = False,
     enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
-    refit: bool = False,
-    debug: bool = False,
-    capability: _enums.EngineCapability = _enums.EngineCapability.default,
-    num_avg_timing_iters: int = 1,
-    workspace_size: int = 0,
-    dla_sram_size: int = 1048576,
-    dla_local_dram_size: int = 1073741824,
-    dla_global_dram_size: int = 536870912,
-    truncate_long_and_double: int = False,
-    calibrator: object = None,
-    allow_shape_tensors: bool = False,
+    debug: bool = DEBUG,
+    workspace_size: int = WORKSPACE_SIZE,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Set[str] = set(),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    device: Device = Device._current_device(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    disable_tf32: bool = DISABLE_TF32,
+    sparse_weights: bool = SPARSE_WEIGHTS,
+    refit: bool = REFIT,
+    engine_capability: EngineCapability = ENGINE_CAPABILITY,
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
+    dla_sram_size: int = DLA_SRAM_SIZE,
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
+    calibrator: object = None,
+    allow_shape_tensors: bool = False,
 ) -> bytes:
+    """Convert a GraphModule module method to a serialized TensorRT engine
+
+    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+
+    Arguments:
+        module (torch.fx.GraphModule): Source module
+
+    Keyword Args:
+        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum
+            to select the device type. ::
+
+                inputs=[
+                    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ),  # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 244))  # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+        method_name (str): Name of method to convert
+        input_signature (Union(List, Tuple, torch_tensorrt.Input, torch.Tensor)): A formatted collection of input specifications for the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select the device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ),  # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 244)))  # Use an example tensor and let torch_tensorrt infer settings for input #3
+
+        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+                device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT engine block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization level 0-5; higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use the Python runtime or the C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing the C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+        device (Device): GPU to compile the model on
+        require_full_compilation (bool): Whether to require that the graph is fully compiled in TensorRT.
+            Only applicable for `ir="dynamo"`; has no effect for the `torch.compile` path
+        disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
+        sparse_weights (bool): Whether to allow the builder to use sparse weights
+        refit (bool): Whether to build a refittable engine
+        engine_capability (trt.EngineCapability): Restrict kernel selection to safe GPU kernels or safe DLA kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        dla_sram_size (int): Fast software-managed RAM used by DLA to communicate within a layer
+        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
+        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 calibration
+        allow_shape_tensors (Experimental): Allow aten::size to output shape tensors using IShapeLayer in TensorRT
+
+    Returns:
+        bytes: Serialized TensorRT engine; can either be saved to a file or deserialized via TensorRT APIs
+    """
     if debug:
         set_log_level(logger.parent, logging.DEBUG)

@@ -504,18 +587,33 @@ def convert_method_to_trt_engine(
     compilation_options = {
         "precision": precision,
         "debug": debug,
-        "device": device,
         "workspace_size": workspace_size,
-        "truncate_long_and_double": truncate_long_and_double,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
+        "device": device,
+        "require_full_compilation": require_full_compilation,
+        "disable_tf32": disable_tf32,
+        "sparse_weights": sparse_weights,
+        "refit": refit,
+        "engine_capability": engine_capability,
+        "num_avg_timing_iters": num_avg_timing_iters,
+        "dla_sram_size": dla_sram_size,
+        "dla_local_dram_size": dla_local_dram_size,
+        "dla_global_dram_size": dla_global_dram_size,
     }

     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     try:
-        interpreter_result = interpret_module(module, input_list, settings, method_name)
+        interpreter_result = interpret_module_to_result(module, input_list, settings)
     except UnsupportedOperatorException:
         logger.error(
             f"Conversion of module {module} not currently fully supported or convertible!",

py/torch_tensorrt/dynamo/conversion/__init__.py (1 addition, 1 deletion)
@@ -1,5 +1,5 @@
 from . import aten_ops_converters, ops_evaluators, prims_ops_converters
-from ._conversion import convert_module, interpret_module
+from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
 from ._TRTInterpreter import *  # noqa: F403
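Because the rename is re-exported here without an alias, any downstream caller needs a one-line import update, e.g. in a hypothetical caller::

    # before this commit
    from torch_tensorrt.dynamo.conversion import interpret_module

    # after this commit
    from torch_tensorrt.dynamo.conversion import interpret_module_to_result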

py/torch_tensorrt/dynamo/conversion/_conversion.py (28 additions, 2 deletions)
@@ -15,12 +15,19 @@
 from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device


-def interpret_module(
+def interpret_module_to_result(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
     settings: CompilationSettings = CompilationSettings(),
-    name: str = "",
 ) -> TRTInterpreterResult:
+    """Interpret an FX module to a TRTInterpreterResult
+    Args:
+        module: FX GraphModule to interpret
+        inputs: Sequence of Tensors representing inputs to the module
+        settings: Compilation settings
+    Returns:
+        TRTInterpreterResult
+    """
     torch_inputs = get_torch_inputs(inputs, settings.device)
     module.to(to_torch_device(settings.device))
     module_outputs = module(*torch_inputs)
@@ -47,6 +54,25 @@ def interpret_module(
         compilation_settings=settings,
     )
     interpreter_result = interpreter.run()
+    return interpreter_result
+
+
+def convert_module(
+    module: torch.fx.GraphModule,
+    inputs: Sequence[Input],
+    settings: CompilationSettings = CompilationSettings(),
+    name: str = "",
+) -> PythonTorchTensorRTModule | TorchTensorRTModule:
+    """Convert an FX module to a TRT module
+    Args:
+        module: FX GraphModule to convert
+        inputs: Sequence of Tensors representing inputs to the module
+        settings: Compilation settings
+        name: TRT engine name
+    Returns:
+        PythonTorchTensorRTModule or TorchTensorRTModule
+    """
+    interpreter_result = interpret_module_to_result(module, inputs, settings)

     if settings.use_python_runtime:
         return PythonTorchTensorRTModule(
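The split separates interpretation from runtime wrapping: interpret_module_to_result returns the raw TRTInterpreterResult, while convert_module delegates to it and wraps the result in a runtime module. A sketch of the two entry points, with the inputs assumed to come from the surrounding dynamo pipeline (names here are illustrative)::

    from torch_tensorrt.dynamo.conversion import (
        convert_module,
        interpret_module_to_result,
    )

    # Assumed available from the surrounding dynamo pipeline:
    #   gm       - a lowered torch.fx.GraphModule
    #   inputs   - a list of torch_tensorrt.Input specs
    #   settings - a CompilationSettings instance

    # Low-level entry point: interpret only, yielding a TRTInterpreterResult
    result = interpret_module_to_result(gm, inputs, settings)

    # High-level entry point: interpret, then wrap in a Python- or C++-runtime module
    trt_module = convert_module(gm, inputs, settings, name="forward")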
