Skip to content

Commit af8cce9

Browse files
committed
feat: add convert_method_to_trt_engine() for dynamo
1 parent 96fe09a commit af8cce9

File tree

3 files changed

+129
-2
lines changed

3 files changed

+129
-2
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,12 @@ def convert_method_to_trt_engine(
319319
"convert_method_to_trt_engine call is not supported for ir=fx"
320320
)
321321
elif target_ir == _IRType.dynamo:
322-
raise RuntimeError(
323-
"convert_method_to_trt_engine call is not supported for ir=dynamo."
322+
return torch_tensorrt.dynamo.convert_method_to_trt_engine( # type: ignore[no-any-return]
323+
module,
324+
inputs=inputs,
325+
method_name=method_name,
326+
enabled_precisions=enabled_precisions_set,
327+
**kwargs,
324328
)
325329
elif target_ir == _IRType.torch_compile:
326330
raise RuntimeError(

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@
1212
from ._settings import CompilationSettings
1313
from ._SourceIR import SourceIR
1414
from ._tracer import trace
15+
16+
from torch_tensorrt.dynamo._compiler import * # noqa: F403

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
import logging
55
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

7+
import tensorrt as trt
78
import torch
89
import torch_tensorrt
910
from torch.export import ExportedProgram
11+
from torch_tensorrt import _enums
1012
from torch_tensorrt._Device import Device
1113
from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum
1214
EngineCapability,
@@ -342,3 +344,122 @@ def compile_module(
342344
settings.use_fast_partitioner = True
343345

344346
return partitioned_module
347+
348+
349+
def interpreter(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: CompilationSettings = CompilationSettings(),
    name: str = "",
) -> TRTInterpreterResult:
    """Run the TensorRT interpreter over ``module`` and return its result.

    The module is first executed once on torch inputs derived from ``inputs``
    (via ``get_torch_inputs``) so that the dtype of every output is known
    before the ``TRTInterpreter`` is constructed.

    Args:
        module: FX graph module to interpret.
        inputs: Input specifications; used both to build the sample torch
            inputs and as the interpreter's input specs.
        settings: Compilation settings read for device, debug, precision,
            workspace size, truncation and the other interpreter options.
        name: Accepted for the caller's convenience; not used in this body.

    Returns:
        Whatever ``TRTInterpreter.run`` returns for the given settings.
    """
    sample_inputs = get_torch_inputs(inputs, settings.device)
    outputs = module(*sample_inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    # Int64 outputs can sometimes be generated from within other operators
    # such as aten.sum - such outputs can be truncated. The table is empty
    # when truncation is disabled, so every dtype then passes through as-is.
    narrowed = (
        {torch.float64: torch.float32, torch.int64: torch.int32}
        if settings.truncate_long_and_double
        else {}
    )
    output_dtypes = [narrowed.get(out.dtype, out.dtype) for out in outputs]

    if settings.debug:
        logger_level = trt.Logger.VERBOSE
        profiling_verbosity = trt.ProfilingVerbosity.VERBOSE
    else:
        logger_level = trt.Logger.WARNING
        profiling_verbosity = trt.ProfilingVerbosity.LAYER_NAMES_ONLY

    # Named differently from the enclosing function to avoid shadowing it.
    trt_interpreter = TRTInterpreter(
        module,
        inputs,
        logger_level=logger_level,
        output_dtypes=output_dtypes,
        compilation_settings=settings,
    )
    return trt_interpreter.run(
        workspace_size=settings.workspace_size,
        precision=settings.precision,
        profiling_verbosity=profiling_verbosity,
        max_aux_streams=settings.max_aux_streams,
        version_compatible=settings.version_compatible,
        optimization_level=settings.optimization_level,
    )
392+
393+
394+
def convert_method_to_trt_engine(
    module: torch.fx.GraphModule,
    method_name: str = "forward",
    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
    # NOTE(review): this default is evaluated once, when the module is
    # imported, not on each call - confirm that is intended.
    device: Device = Device._current_device(),
    disable_tf32: bool = False,
    sparse_weights: bool = False,
    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
    refit: bool = False,
    debug: bool = False,
    capability: _enums.EngineCapability = _enums.EngineCapability.default,
    num_avg_timing_iters: int = 1,
    workspace_size: int = 0,
    dla_sram_size: int = 1048576,
    dla_local_dram_size: int = 1073741824,
    dla_global_dram_size: int = 536870912,
    truncate_long_and_double: bool = False,
    calibrator: object = None,
    allow_shape_tensors: bool = False,
) -> bytes:
    """Convert an FX graph module to a serialized TensorRT engine.

    The inputs are normalized with ``prepare_inputs``, a single precision is
    selected from ``enabled_precisions``, a ``CompilationSettings`` object is
    built, and the module is handed to ``interpreter``. The resulting engine
    is serialized and returned as bytes.

    Args:
        module: FX graph module to convert.
        method_name: Passed to ``interpreter`` as its ``name`` argument.
        inputs: Input specs or tensors; ``None`` is treated as no inputs.
        device: Target device, normalized with ``to_torch_tensorrt_device``.
        enabled_precisions: Precisions to choose from. ``None`` means
            ``{torch.float}``. Half precision is selected if present,
            otherwise float; an empty set falls back to ``PRECISION``.
        debug: Enables debug-level logging and is forwarded to the settings.
        workspace_size: Forwarded to the settings.
        truncate_long_and_double: Forwarded to the settings.
        disable_tf32, sparse_weights, refit, capability, num_avg_timing_iters,
        dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator,
        allow_shape_tensors: Accepted for signature compatibility; this
            function does not forward them to ``CompilationSettings``.

    Returns:
        The serialized TensorRT engine.

    Raises:
        ValueError: If ``enabled_precisions`` is non-empty but contains
            neither a half nor a float precision.
    """
    if debug:
        set_log_level(logger.parent, logging.DEBUG)

    input_list = list(inputs) if inputs is not None else []
    # Prepare torch_trt inputs
    input_list = prepare_inputs(input_list)
    device = to_torch_tensorrt_device(device)

    enabled_precisions = (
        enabled_precisions if enabled_precisions is not None else {torch.float}
    )

    # Half precision wins over float when both are enabled.
    if (
        torch.float16 in enabled_precisions
        or torch_tensorrt.dtype.half in enabled_precisions
    ):
        precision = torch.float16
    elif (
        torch.float32 in enabled_precisions
        or torch_tensorrt.dtype.float in enabled_precisions
    ):
        precision = torch.float32
    elif not enabled_precisions:
        logger.info("No precision specified, defaulting to %s", PRECISION)
        precision = PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )

    compilation_options = {
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "truncate_long_and_double": truncate_long_and_double,
        "max_aux_streams": MAX_AUX_STREAMS,
        "version_compatible": VERSION_COMPATIBLE,
        "optimization_level": OPTIMIZATION_LEVEL,
    }

    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)
    interpreter_result = interpreter(module, input_list, settings, method_name)

    # Local import: io is only needed for this final serialization step.
    import io

    # Copy the serialized engine into an owned bytes object.
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(interpreter_result.engine.serialize())
        engine_bytearray = engine_bytes.getvalue()

    return engine_bytearray

0 commit comments

Comments
 (0)