Skip to content

Commit f96d3e5

Browse files
committed
feat: add convert_method_to_trt_engine() for dynamo
1 parent b8403b8 commit f96d3e5

File tree

3 files changed

+129
-2
lines changed

3 files changed

+129
-2
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,12 @@ def convert_method_to_trt_engine(
319319
"convert_method_to_trt_engine call is not supported for ir=fx"
320320
)
321321
elif target_ir == _IRType.dynamo:
322-
raise RuntimeError(
323-
"convert_method_to_trt_engine call is not supported for ir=dynamo."
322+
return torch_tensorrt.dynamo.convert_method_to_trt_engine( # type: ignore[no-any-return]
323+
module,
324+
inputs=inputs,
325+
method_name=method_name,
326+
enabled_precisions=enabled_precisions_set,
327+
**kwargs,
324328
)
325329
elif target_ir == _IRType.torch_compile:
326330
raise RuntimeError(

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
from ._settings import CompilationSettings
1313
from ._SourceIR import SourceIR
1414
from ._tracer import trace
15+
16+
from torch_tensorrt.dynamo._compiler import * # noqa: F403

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import logging
55
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
66

7+
import tensorrt as trt
78
import torch
89
from torch.export import ExportedProgram
910
from torch.fx.node import Target
11+
from torch_tensorrt import _enums
1012
from torch_tensorrt._Device import Device
1113
from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum
1214
EngineCapability,
@@ -443,3 +445,122 @@ def compile_module(
443445
dryrun_stats_display(dryrun_tracker, settings.dryrun)
444446

445447
return partitioned_module
448+
449+
450+
def interpreter(
451+
module: torch.fx.GraphModule,
452+
inputs: Sequence[Input],
453+
settings: CompilationSettings = CompilationSettings(),
454+
name: str = "",
455+
) -> TRTInterpreterResult:
456+
torch_inputs = get_torch_inputs(inputs, settings.device)
457+
module_outputs = module(*torch_inputs)
458+
459+
if not isinstance(module_outputs, (list, tuple)):
460+
module_outputs = [module_outputs]
461+
462+
# Int64 outputs can sometimes be generated from within other operators
463+
# such as aten.sum - such outputs can be truncated
464+
output_dtypes = []
465+
for output in module_outputs:
466+
if settings.truncate_long_and_double and output.dtype == torch.float64:
467+
output_dtypes.append(torch.float32)
468+
elif settings.truncate_long_and_double and output.dtype == torch.int64:
469+
output_dtypes.append(torch.int32)
470+
else:
471+
output_dtypes.append(output.dtype)
472+
473+
interpreter = TRTInterpreter(
474+
module,
475+
inputs,
476+
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
477+
output_dtypes=output_dtypes,
478+
compilation_settings=settings,
479+
)
480+
interpreter_result = interpreter.run(
481+
workspace_size=settings.workspace_size,
482+
precision=settings.precision,
483+
profiling_verbosity=(
484+
trt.ProfilingVerbosity.VERBOSE
485+
if settings.debug
486+
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
487+
),
488+
max_aux_streams=settings.max_aux_streams,
489+
version_compatible=settings.version_compatible,
490+
optimization_level=settings.optimization_level,
491+
)
492+
return interpreter_result
493+
494+
495+
def convert_method_to_trt_engine(
496+
module: torch.fx.GraphModule,
497+
method_name: str = "forward",
498+
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
499+
device: Device = Device._current_device(),
500+
disable_tf32: bool = False,
501+
sparse_weights: bool = False,
502+
enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
503+
refit: bool = False,
504+
debug: bool = False,
505+
capability: _enums.EngineCapability = _enums.EngineCapability.default,
506+
num_avg_timing_iters: int = 1,
507+
workspace_size: int = 0,
508+
dla_sram_size: int = 1048576,
509+
dla_local_dram_size: int = 1073741824,
510+
dla_global_dram_size: int = 536870912,
511+
truncate_long_and_double: int = False,
512+
calibrator: object = None,
513+
allow_shape_tensors: bool = False,
514+
) -> bytes:
515+
if debug:
516+
set_log_level(logger.parent, logging.DEBUG)
517+
518+
input_list = list(inputs) if inputs is not None else []
519+
# Prepare torch_trt inputs
520+
input_list = prepare_inputs(input_list)
521+
device = to_torch_tensorrt_device(device)
522+
523+
enabled_precisions = (
524+
enabled_precisions if enabled_precisions is not None else {torch.float}
525+
)
526+
527+
if (
528+
torch.float16 in enabled_precisions
529+
or torch_tensorrt.dtype.half in enabled_precisions
530+
):
531+
precision = torch.float16
532+
elif (
533+
torch.float32 in enabled_precisions
534+
or torch_tensorrt.dtype.float in enabled_precisions
535+
):
536+
precision = torch.float32
537+
elif len(enabled_precisions) == 0:
538+
logger.info(f"No precision specified, defaulting to {PRECISION}")
539+
precision = PRECISION
540+
else:
541+
raise ValueError(
542+
f"Precision {enabled_precisions} not supported in the Dynamo Path"
543+
)
544+
545+
compilation_options = {
546+
"precision": precision,
547+
"debug": debug,
548+
"device": device,
549+
"workspace_size": workspace_size,
550+
"truncate_long_and_double": truncate_long_and_double,
551+
"max_aux_streams": MAX_AUX_STREAMS,
552+
"version_compatible": VERSION_COMPATIBLE,
553+
"optimization_level": OPTIMIZATION_LEVEL,
554+
}
555+
556+
settings = CompilationSettings(**compilation_options)
557+
logger.info("Compilation Settings: %s\n", settings)
558+
interpreter_result = interpreter(module, input_list, settings, method_name)
559+
560+
import io
561+
562+
with io.BytesIO() as engine_bytes:
563+
engine_bytes.write(interpreter_result.engine.serialize())
564+
engine_bytearray = engine_bytes.getvalue()
565+
566+
return engine_bytearray

0 commit comments

Comments
 (0)