Commit 5f66ade

feat: add convert_method_to_trt_engine() for dynamo (#2467)

1 parent fd19353 commit 5f66ade

6 files changed: +274 -15 lines

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 2 deletions

@@ -319,8 +319,12 @@ def convert_method_to_trt_engine(
             "convert_method_to_trt_engine call is not supported for ir=fx"
         )
     elif target_ir == _IRType.dynamo:
-        raise RuntimeError(
-            "convert_method_to_trt_engine call is not supported for ir=dynamo."
+        return torch_tensorrt.dynamo.convert_module_to_trt_engine(  # type: ignore[no-any-return]
+            module,
+            inputs=inputs,
+            method_name=method_name,
+            enabled_precisions=enabled_precisions_set,
+            **kwargs,
         )
     elif target_ir == _IRType.torch_compile:
         raise RuntimeError(
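For context, a minimal sketch of what this change enables through the top-level API (the module, shapes, and precision here are illustrative, and a CUDA-capable environment is assumed):

    import torch
    import torch_tensorrt

    class Add(torch.nn.Module):
        def forward(self, a, b):
            return torch.add(a, b)

    # The dynamo path operates on torch.fx graphs, so trace the model first.
    gm = torch.fx.symbolic_trace(Add())

    # With this commit, ir="dynamo" forwards to
    # torch_tensorrt.dynamo.convert_module_to_trt_engine instead of raising,
    # and returns the serialized TensorRT engine as bytes.
    engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
        gm,
        method_name="forward",
        inputs=[torch.randn((2, 4)), torch.randn((2, 4))],
        ir="dynamo",
        enabled_precisions={torch.float},
    )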

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 logger = logging.getLogger(__name__)

 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._compiler import compile
+    from ._compiler import compile, convert_module_to_trt_engine
     from ._exporter import export
     from ._settings import CompilationSettings
     from ._SourceIR import SourceIR

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 189 additions & 0 deletions

@@ -7,6 +7,7 @@
 import torch
 from torch.export import ExportedProgram
 from torch.fx.node import Target
+from torch_tensorrt import _enums
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import (  # TODO: Should probabably be the TRT EngineCapability Enum
     EngineCapability,
@@ -47,7 +48,9 @@
 )
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
+    UnsupportedOperatorException,
     convert_module,
+    interpret_module_to_result,
     repair_long_or_double_inputs,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -443,3 +446,189 @@ def compile_module(
     dryrun_stats_display(dryrun_tracker, settings.dryrun)

     return partitioned_module
+
+
+def convert_module_to_trt_engine(
+    module: torch.fx.GraphModule,
+    method_name: str = "forward",
+    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
+    debug: bool = DEBUG,
+    workspace_size: int = WORKSPACE_SIZE,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Set[str] = set(),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    device: Device = Device._current_device(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    disable_tf32: bool = DISABLE_TF32,
+    sparse_weights: bool = SPARSE_WEIGHTS,
+    refit: bool = REFIT,
+    engine_capability: EngineCapability = ENGINE_CAPABILITY,
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
+    dla_sram_size: int = DLA_SRAM_SIZE,
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
+    calibrator: object = None,
+    allow_shape_tensors: bool = False,
+) -> bytes:
+    """Convert a GraphModule module method to a serialized TensorRT engine
+
+    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+
+    Arguments:
+        module (torch.fx.GraphModule): Source module
+
+    Keyword Args:
+        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                input=[
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32
+                        format=torch.channel_last
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+        method_name (str): Name of method to convert
+        input_signature Union(List, Tuple, torch_tensorrt.Input, torch.Tensor): A formatted collection of input specifications for the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32
+                        format=torch.channel_last
+                    ), # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 244))) # Use an example tensor and let torch_tensorrt infer settings for input #3
+
+        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+            device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT-Engine Block
+        torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+        device (Device): GPU to compile the model on
+        require_full_compilation (bool): Whether to require the graph is fully compiled in TensorRT.
+            Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
+        disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
+        sparse_weights (bool): Whether to allow the builder to use sparse weights
+        refit (bool): Whether to build a refittable engine
+        engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
+        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
+        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+        allow_shape_tensors: (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
+
+    Returns:
+        bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
+    """
+    if debug:
+        set_log_level(logger.parent, logging.DEBUG)
+
+    input_list = list(inputs) if inputs is not None else []
+    # Prepare torch_trt inputs
+    input_list = prepare_inputs(input_list)
+    device = to_torch_tensorrt_device(device)
+
+    enabled_precisions = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
+    if (
+        torch.float16 in enabled_precisions
+        or torch_tensorrt.dtype.half in enabled_precisions
+    ):
+        precision = torch.float16
+    elif (
+        torch.float32 in enabled_precisions
+        or torch_tensorrt.dtype.float in enabled_precisions
+    ):
+        precision = torch.float32
+    elif len(enabled_precisions) == 0:
+        logger.info(f"No precision specified, defaulting to {PRECISION}")
+        precision = PRECISION
+    else:
+        raise ValueError(
+            f"Precision {enabled_precisions} not supported in the Dynamo Path"
+        )
+
+    compilation_options = {
+        "precision": precision,
+        "debug": debug,
+        "workspace_size": workspace_size,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
+        "max_aux_streams": max_aux_streams,
+        "version_compatible": version_compatible,
+        "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
+        "device": device,
+        "require_full_compilation": require_full_compilation,
+        "disable_tf32": disable_tf32,
+        "sparse_weights": sparse_weights,
+        "refit": refit,
+        "engine_capability": engine_capability,
+        "num_avg_timing_iters": num_avg_timing_iters,
+        "dla_sram_size": dla_sram_size,
+        "dla_local_dram_size": dla_local_dram_size,
+        "dla_global_dram_size": dla_global_dram_size,
+    }
+
+    settings = CompilationSettings(**compilation_options)
+    logger.info("Compilation Settings: %s\n", settings)
+    try:
+        interpreter_result = interpret_module_to_result(module, input_list, settings)
+    except UnsupportedOperatorException:
+        logger.error(
+            f"Conversion of module {module} not currently fully supported or convertible!",
+            exc_info=True,
+        )
+    except Exception as e:
+        logger.error(
+            f"While interpreting the module got an error: {e}",
+            exc_info=True,
+        )
+
+    import io
+
+    with io.BytesIO() as engine_bytes:
+        engine_bytes.write(interpreter_result.engine.serialize())
+        engine_bytearray = engine_bytes.getvalue()
+
+    return engine_bytearray
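The returned bytes can be saved to disk or handed straight back to the TensorRT runtime, mirroring what the new test below does. A hedged sketch, assuming engine_bytes from the earlier example (the file name is illustrative):

    import tensorrt as trt

    # Persist the serialized engine for later deployment.
    with open("module.engine", "wb") as f:
        f.write(engine_bytes)

    # Or rebuild a live engine immediately via the TensorRT Python API.
    with trt.Logger() as logger, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(engine_bytes)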

py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 from . import aten_ops_converters, ops_evaluators, prims_ops_converters
-from ._conversion import convert_module
+from ._conversion import convert_module, interpret_module_to_result
 from ._ConversionContext import ConversionContext
 from ._ConverterRegistry import *  # noqa: F403
 from ._TRTInterpreter import *  # noqa: F403

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 30 additions & 11 deletions

@@ -7,29 +7,29 @@
 import torch
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
+from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
+    TRTInterpreter,
+    TRTInterpreterResult,
+)
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.utils import get_torch_inputs
+from torch_tensorrt.dynamo.utils import get_torch_inputs, to_torch_device


-def convert_module(
+def interpret_module_to_result(
     module: torch.fx.GraphModule,
     inputs: Sequence[Input],
     settings: CompilationSettings = CompilationSettings(),
-    name: str = "",
-) -> PythonTorchTensorRTModule | TorchTensorRTModule:
-    """Convert an FX module to a TRT module
+) -> TRTInterpreterResult:
+    """Interpret an FX module to a TRTInterpreterResult
     Args:
-        module: FX GraphModule to convert
+        module: FX GraphModule to interpret
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
-        name: TRT engine name
     Returns:
-        _PythonTorchTensorRTModule or TorchTensorRTModule
+        TRTInterpreterResult
     """
-    # Specify module output data types to ensure TRT output types agree with
-    # that of the equivalent Torch module
     torch_inputs = get_torch_inputs(inputs, settings.device)
+    module.to(to_torch_device(settings.device))
     module_outputs = module(*torch_inputs)

     if not isinstance(module_outputs, (list, tuple)):
@@ -54,6 +54,25 @@ def convert_module(
         compilation_settings=settings,
     )
     interpreter_result = interpreter.run()
+    return interpreter_result
+
+
+def convert_module(
+    module: torch.fx.GraphModule,
+    inputs: Sequence[Input],
+    settings: CompilationSettings = CompilationSettings(),
+    name: str = "",
+) -> PythonTorchTensorRTModule | TorchTensorRTModule:
+    """Convert an FX module to a TRT module
+    Args:
+        module: FX GraphModule to convert
+        inputs: Sequence of Tensors representing inputs to the module
+        settings: Compilation settings
+        name: TRT engine name
+    Returns:
+        _PythonTorchTensorRTModule or TorchTensorRTModule
+    """
+    interpreter_result = interpret_module_to_result(module, inputs, settings)

     if settings.use_python_runtime:
         return PythonTorchTensorRTModule(
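After this refactor, interpret_module_to_result is the shared core: convert_module wraps its TRTInterpreterResult in a runtime module, while convert_module_to_trt_engine serializes its engine. A rough sketch of calling it directly, with an illustrative module and default settings:

    import torch
    from torch_tensorrt import Input
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.conversion import interpret_module_to_result

    gm = torch.fx.symbolic_trace(torch.nn.ReLU())

    # Runs the TRTInterpreter and returns the raw TRTInterpreterResult;
    # its .engine can be serialized exactly as convert_module_to_trt_engine does.
    result = interpret_module_to_result(
        gm, inputs=[Input((2, 4))], settings=CompilationSettings()
    )
    serialized_engine = bytes(result.engine.serialize())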
New test file

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+import unittest
+
+import tensorrt as trt
+import torch
+import torch_tensorrt
+from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
+from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
+
+
+class TestConvertMethodToTrtEngine(unittest.TestCase):
+    def test_convert_module(self):
+        class Test(torch.nn.Module):
+            def forward(self, a, b):
+                return torch.add(a, b)
+
+        # Prepare the input data
+        input_data_0, input_data_1 = torch.randn((2, 4)), torch.randn((2, 4))
+
+        # Create a model
+        model = Test()
+        symbolic_traced_gm = torch.fx.symbolic_trace(model)
+
+        # Convert to TensorRT engine
+        trt_engine_str = torch_tensorrt.dynamo.convert_module_to_trt_engine(
+            symbolic_traced_gm, "forward", inputs=[input_data_0, input_data_1]
+        )
+
+        # Deserialize the TensorRT engine
+        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
+            engine = runtime.deserialize_cuda_engine(trt_engine_str)
+
+        # Inference on TRT Engine
+        py_trt_module = PythonTorchTensorRTModule(engine, ["a", "b"], ["output0"])
+        trt_output = py_trt_module(input_data_0, input_data_1).cpu()
+
+        # Inference on PyTorch model
+        model_output = model(input_data_0, input_data_1)
+
+        cos_sim = cosine_similarity(model_output, trt_output)
+        self.assertTrue(
+            cos_sim > COSINE_THRESHOLD,
+            msg=f"TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
