Skip to content

Commit d34966e

Browse files
committed
feat: add convert_method_to_trt_engine() for dynamo
1 parent cd158b6 commit d34966e

File tree

3 files changed

+133
-2
lines changed

3 files changed

+133
-2
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,12 @@ def convert_method_to_trt_engine(
319319
"convert_method_to_trt_engine call is not supported for ir=fx"
320320
)
321321
elif target_ir == _IRType.dynamo:
322-
raise RuntimeError(
323-
"convert_method_to_trt_engine call is not supported for ir=dynamo."
322+
return torch_tensorrt.dynamo.convert_method_to_trt_engine( # type: ignore[no-any-return]
323+
module,
324+
inputs=inputs,
325+
method_name=method_name,
326+
enabled_precisions=enabled_precisions_set,
327+
**kwargs,
324328
)
325329
elif target_ir == _IRType.torch_compile:
326330
raise RuntimeError(

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
from ._settings import CompilationSettings
1313
from ._SourceIR import SourceIR
1414
from ._tracer import trace
15+
16+
from torch_tensorrt.dynamo._compiler import * # noqa: F403

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import logging
55
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

7+
import tensorrt as trt
78
import torch
89
import torch_tensorrt
910
from torch.export import ExportedProgram
11+
from torch_tensorrt import _enums
1012
from torch_tensorrt._Device import Device
1113
from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum
1214
EngineCapability,
@@ -34,6 +36,10 @@
3436
convert_module,
3537
repair_long_or_double_inputs,
3638
)
39+
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
40+
TRTInterpreter,
41+
TRTInterpreterResult,
42+
)
3743
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
3844
from torch_tensorrt.dynamo.utils import (
3945
get_torch_inputs,
@@ -319,3 +325,122 @@ def compile_module(
319325
settings.use_fast_partitioner = True
320326

321327
return partitioned_module
328+
329+
330+
def interpreter(
331+
module: torch.fx.GraphModule,
332+
inputs: Sequence[Input],
333+
settings: CompilationSettings = CompilationSettings(),
334+
name: str = "",
335+
) -> TRTInterpreterResult:
336+
torch_inputs = get_torch_inputs(inputs, settings.device)
337+
module_outputs = module(*torch_inputs)
338+
339+
if not isinstance(module_outputs, (list, tuple)):
340+
module_outputs = [module_outputs]
341+
342+
# Int64 outputs can sometimes be generated from within other operators
343+
# such as aten.sum - such outputs can be truncated
344+
output_dtypes = []
345+
for output in module_outputs:
346+
if settings.truncate_long_and_double and output.dtype == torch.float64:
347+
output_dtypes.append(torch.float32)
348+
elif settings.truncate_long_and_double and output.dtype == torch.int64:
349+
output_dtypes.append(torch.int32)
350+
else:
351+
output_dtypes.append(output.dtype)
352+
353+
interpreter = TRTInterpreter(
354+
module,
355+
inputs,
356+
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
357+
output_dtypes=output_dtypes,
358+
compilation_settings=settings,
359+
)
360+
interpreter_result = interpreter.run(
361+
workspace_size=settings.workspace_size,
362+
precision=settings.precision,
363+
profiling_verbosity=(
364+
trt.ProfilingVerbosity.VERBOSE
365+
if settings.debug
366+
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
367+
),
368+
max_aux_streams=settings.max_aux_streams,
369+
version_compatible=settings.version_compatible,
370+
optimization_level=settings.optimization_level,
371+
)
372+
return interpreter_result
373+
374+
375+
def convert_method_to_trt_engine(
376+
module: torch.fx.GraphModule,
377+
method_name: str = "forward",
378+
inputs: Optional[Sequence[Input | torch.Tensor]] = None,
379+
device: Device = Device._current_device(),
380+
disable_tf32: bool = False,
381+
sparse_weights: bool = False,
382+
enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
383+
refit: bool = False,
384+
debug: bool = False,
385+
capability: _enums.EngineCapability = _enums.EngineCapability.default,
386+
num_avg_timing_iters: int = 1,
387+
workspace_size: int = 0,
388+
dla_sram_size: int = 1048576,
389+
dla_local_dram_size: int = 1073741824,
390+
dla_global_dram_size: int = 536870912,
391+
truncate_long_and_double: int = False,
392+
calibrator: object = None,
393+
allow_shape_tensors: bool = False,
394+
) -> bytes:
395+
if debug:
396+
set_log_level(logger.parent, logging.DEBUG)
397+
398+
input_list = list(inputs) if inputs is not None else []
399+
# Prepare torch_trt inputs
400+
input_list = prepare_inputs(input_list)
401+
device = to_torch_tensorrt_device(device)
402+
403+
enabled_precisions = (
404+
enabled_precisions if enabled_precisions is not None else {torch.float}
405+
)
406+
407+
if (
408+
torch.float16 in enabled_precisions
409+
or torch_tensorrt.dtype.half in enabled_precisions
410+
):
411+
precision = torch.float16
412+
elif (
413+
torch.float32 in enabled_precisions
414+
or torch_tensorrt.dtype.float in enabled_precisions
415+
):
416+
precision = torch.float32
417+
elif len(enabled_precisions) == 0:
418+
logger.info(f"No precision specified, defaulting to {PRECISION}")
419+
precision = PRECISION
420+
else:
421+
raise ValueError(
422+
f"Precision {enabled_precisions} not supported in the Dynamo Path"
423+
)
424+
425+
compilation_options = {
426+
"precision": precision,
427+
"debug": debug,
428+
"device": device,
429+
"workspace_size": workspace_size,
430+
"truncate_long_and_double": truncate_long_and_double,
431+
"max_aux_streams": MAX_AUX_STREAMS,
432+
"version_compatible": VERSION_COMPATIBLE,
433+
"optimization_level": OPTIMIZATION_LEVEL,
434+
}
435+
436+
settings = CompilationSettings(**compilation_options)
437+
logger.info("Compilation Settings: %s\n", settings)
438+
interpreter_result = interpreter(module, input_list, settings, method_name)
439+
440+
import io
441+
442+
with io.BytesIO() as engine_bytes:
443+
engine_bytes.write(interpreter_result.engine.serialize())
444+
engine_bytearray = engine_bytes.getvalue()
445+
446+
return engine_bytearray

0 commit comments

Comments
 (0)