Commit 38caf26

rebase

1 parent 6507e5a · commit 38caf26

File tree: 3 files changed, +146 −21 lines

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 117 additions & 18 deletions
@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+from dataclasses import field
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
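
Aside on the newly imported helper: `dataclasses.field(default_factory=set)` is the standard way to give a dataclass field a fresh mutable default per instance. A minimal sketch (the class name is hypothetical, for illustration only)::

    from dataclasses import dataclass, field
    from typing import Set

    @dataclass
    class PartitionOptions:  # hypothetical, not part of this commit
        # default_factory gives each instance its own empty set,
        # rather than one set shared across instances
        torch_executed_ops: Set[str] = field(default_factory=set)

    opts = PartitionOptions()
    opts.torch_executed_ops.add("aten::sort")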
@@ -42,7 +43,7 @@
     CompilationSettings,
     UnsupportedOperatorException,
     convert_module,
-    interpret_module,
+    interpret_module_to_result,
     repair_long_or_double_inputs,
 )
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
@@ -351,25 +352,108 @@ def convert_method_to_trt_engine(
     module: torch.fx.GraphModule,
     method_name: str = "forward",
     inputs: Optional[Sequence[Input | torch.Tensor]] = None,
-    device: Device = Device._current_device(),
-    disable_tf32: bool = False,
-    sparse_weights: bool = False,
     enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
-    refit: bool = False,
-    debug: bool = False,
-    capability: _enums.EngineCapability = _enums.EngineCapability.default,
-    num_avg_timing_iters: int = 1,
-    workspace_size: int = 0,
-    dla_sram_size: int = 1048576,
-    dla_local_dram_size: int = 1073741824,
-    dla_global_dram_size: int = 536870912,
-    truncate_long_and_double: int = False,
-    calibrator: object = None,
-    allow_shape_tensors: bool = False,
+    debug: bool = DEBUG,
+    workspace_size: int = WORKSPACE_SIZE,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Set[str] = field(default_factory=set),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    device: Device = Device._current_device(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    disable_tf32: bool = DISABLE_TF32,
+    sparse_weights: bool = SPARSE_WEIGHTS,
+    refit: bool = REFIT,
+    engine_capability: EngineCapability = ENGINE_CAPABILITY,
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
+    dla_sram_size: int = DLA_SRAM_SIZE,
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
+    calibrator: object = None,
+    allow_shape_tensors: bool = False,
 ) -> bytes:
+    """Convert a GraphModule module method to a serialized TensorRT engine
+
+    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+
+    Arguments:
+        module (torch.fx.GraphModule): Source module
+
+    Keyword Args:
+        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists; dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                inputs=[
+                    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ),  # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 224)),  # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+        method_name (str): Name of method to convert
+        input_signature (Union(List, Tuple, torch_tensorrt.Input, torch.Tensor)): A formatted collection of input specifications for the module. Input sizes can be specified as torch sizes, tuples or lists; dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes, and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)),  # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ),  # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 224)))  # Use an example tensor and let torch_tensorrt infer settings for input #3
+
+        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+            device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT engine block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization level 0-5; higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use the Python runtime or the C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing the C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+        device (Device): GPU to compile the model on
+        require_full_compilation (bool): Whether to require that the graph is fully compiled in TensorRT.
+            Only applicable for `ir="dynamo"`; has no effect for the `torch.compile` path
+        disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
+        sparse_weights (bool): Whether to allow the builder to use sparse weights
+        refit (bool): Whether to build a refittable engine
+        engine_capability (trt.EngineCapability): Restrict kernel selection to safe GPU kernels or safe DLA kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        dla_sram_size (int): Fast software-managed RAM used by DLA to communicate within a layer
+        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
+        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 calibration
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
+
+    Returns:
+        bytes: Serialized TensorRT engine; can either be saved to a file or deserialized via TensorRT APIs
+    """
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 

@@ -403,18 +487,33 @@ def convert_method_to_trt_engine(
     compilation_options = {
         "precision": precision,
         "debug": debug,
-        "device": device,
         "workspace_size": workspace_size,
-        "truncate_long_and_double": truncate_long_and_double,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
+        "device": device,
+        "require_full_compilation": require_full_compilation,
+        "disable_tf32": disable_tf32,
+        "sparse_weights": sparse_weights,
+        "refit": refit,
+        "engine_capability": engine_capability,
+        "num_avg_timing_iters": num_avg_timing_iters,
+        "dla_sram_size": dla_sram_size,
+        "dla_local_dram_size": dla_local_dram_size,
+        "dla_global_dram_size": dla_global_dram_size,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     try:
-        interpreter_result = interpret_module(module, input_list, settings, method_name)
+        interpreter_result = interpret_module_to_result(module, input_list, settings)
     except UnsupportedOperatorException:
         logger.error(
             f"Conversion of module {module} not currently fully supported or convertible!",

py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from . import aten_ops_converters, ops_evaluators, prims_ops_converters
-from ._conversion import convert_module, interpret_module
+from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
 from ._TRTInterpreter import *  # noqa: F403
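
With the re-export above, downstream code can import the renamed helper directly from the conversion package, e.g. `from torch_tensorrt.dynamo.conversion import interpret_module_to_result`.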

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 28 additions & 2 deletions
@@ -15,12 +15,19 @@
 from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device
 
 
-def interpret_module(
+def interpret_module_to_result(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
     settings: CompilationSettings = CompilationSettings(),
-    name: str = "",
 ) -> TRTInterpreterResult:
+    """Interpret an FX module to a TRTInterpreterResult
+    Args:
+        module: FX GraphModule to interpret
+        inputs: Sequence of Tensors representing inputs to the module
+        settings: Compilation settings
+    Returns:
+        TRTInterpreterResult
+    """
     torch_inputs = get_torch_inputs(inputs, settings.device)
     module.to(to_torch_device(settings.device))
     module_outputs = module(*torch_inputs)
@@ -47,6 +54,25 @@ def interpret_module(
             compilation_settings=settings,
         )
     interpreter_result = interpreter.run()
+    return interpreter_result
+
+
+def convert_module(
+    module: torch.fx.GraphModule,
+    inputs: Sequence[Input],
+    settings: CompilationSettings = CompilationSettings(),
+    name: str = "",
+) -> PythonTorchTensorRTModule | TorchTensorRTModule:
+    """Convert an FX module to a TRT module
+    Args:
+        module: FX GraphModule to convert
+        inputs: Sequence of Tensors representing inputs to the module
+        settings: Compilation settings
+        name: TRT engine name
+    Returns:
+        PythonTorchTensorRTModule or TorchTensorRTModule
+    """
+    interpreter_result = interpret_module_to_result(module, inputs, settings)
 
     if settings.use_python_runtime:
         return PythonTorchTensorRTModule(
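
Taken together, the refactor splits conversion into two steps: interpretation, which yields a `TRTInterpreterResult` holding the serialized engine, and wrapping that result in a runtime module. A minimal sketch of the flow, assuming the import path in this commit and treating `gm` and `example_inputs` as hypothetical placeholders::

    from torch_tensorrt.dynamo.conversion import (
        convert_module,
        interpret_module_to_result,
    )

    # Step 1: interpretation only -- produces a TRTInterpreterResult carrying
    # the serialized engine plus input/output binding names.
    result = interpret_module_to_result(gm, example_inputs)

    # Step 2: convert_module() now delegates to interpret_module_to_result()
    # and wraps the result in a Python- or C++-runtime module, chosen by
    # settings.use_python_runtime.
    trt_module = convert_module(gm, example_inputs, name="forward")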
