Skip to content

Commit 9910c62

Browse files
committed
fix issues from the comments
1 parent f96d3e5 commit 9910c62

File tree

4 files changed: +23 additions, −69 deletions

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
logger = logging.getLogger(__name__)
88

99
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
10-
from ._compiler import compile
10+
from ._compiler import compile, convert_method_to_trt_engine
1111
from ._exporter import export
1212
from ._settings import CompilationSettings
1313
from ._SourceIR import SourceIR
1414
from ._tracer import trace
15-
16-
from torch_tensorrt.dynamo._compiler import * # noqa: F403

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 15 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
66

7-
import tensorrt as trt
87
import torch
98
from torch.export import ExportedProgram
109
from torch.fx.node import Target
@@ -49,7 +48,9 @@
4948
)
5049
from torch_tensorrt.dynamo.conversion import (
5150
CompilationSettings,
51+
UnsupportedOperatorException,
5252
convert_module,
53+
interpret_module,
5354
repair_long_or_double_inputs,
5455
)
5556
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -447,51 +448,6 @@ def compile_module(
447448
return partitioned_module
448449

449450

450-
def interpreter(
451-
module: torch.fx.GraphModule,
452-
inputs: Sequence[Input],
453-
settings: CompilationSettings = CompilationSettings(),
454-
name: str = "",
455-
) -> TRTInterpreterResult:
456-
torch_inputs = get_torch_inputs(inputs, settings.device)
457-
module_outputs = module(*torch_inputs)
458-
459-
if not isinstance(module_outputs, (list, tuple)):
460-
module_outputs = [module_outputs]
461-
462-
# Int64 outputs can sometimes be generated from within other operators
463-
# such as aten.sum - such outputs can be truncated
464-
output_dtypes = []
465-
for output in module_outputs:
466-
if settings.truncate_long_and_double and output.dtype == torch.float64:
467-
output_dtypes.append(torch.float32)
468-
elif settings.truncate_long_and_double and output.dtype == torch.int64:
469-
output_dtypes.append(torch.int32)
470-
else:
471-
output_dtypes.append(output.dtype)
472-
473-
interpreter = TRTInterpreter(
474-
module,
475-
inputs,
476-
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
477-
output_dtypes=output_dtypes,
478-
compilation_settings=settings,
479-
)
480-
interpreter_result = interpreter.run(
481-
workspace_size=settings.workspace_size,
482-
precision=settings.precision,
483-
profiling_verbosity=(
484-
trt.ProfilingVerbosity.VERBOSE
485-
if settings.debug
486-
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
487-
),
488-
max_aux_streams=settings.max_aux_streams,
489-
version_compatible=settings.version_compatible,
490-
optimization_level=settings.optimization_level,
491-
)
492-
return interpreter_result
493-
494-
495451
def convert_method_to_trt_engine(
496452
module: torch.fx.GraphModule,
497453
method_name: str = "forward",
@@ -511,6 +467,9 @@ def convert_method_to_trt_engine(
511467
truncate_long_and_double: int = False,
512468
calibrator: object = None,
513469
allow_shape_tensors: bool = False,
470+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
471+
version_compatible: bool = VERSION_COMPATIBLE,
472+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
514473
) -> bytes:
515474
if debug:
516475
set_log_level(logger.parent, logging.DEBUG)
@@ -548,15 +507,20 @@ def convert_method_to_trt_engine(
548507
"device": device,
549508
"workspace_size": workspace_size,
550509
"truncate_long_and_double": truncate_long_and_double,
551-
"max_aux_streams": MAX_AUX_STREAMS,
552-
"version_compatible": VERSION_COMPATIBLE,
553-
"optimization_level": OPTIMIZATION_LEVEL,
510+
"max_aux_streams": max_aux_streams,
511+
"version_compatible": version_compatible,
512+
"optimization_level": optimization_level,
554513
}
555514

556515
settings = CompilationSettings(**compilation_options)
557516
logger.info("Compilation Settings: %s\n", settings)
558-
interpreter_result = interpreter(module, input_list, settings, method_name)
559-
517+
try:
518+
interpreter_result = interpret_module(module, input_list, settings, method_name)
519+
except UnsupportedOperatorException:
520+
logger.error(
521+
f"Conversion of module {module} not currently fully supported or convertible!",
522+
exc_info=True,
523+
)
560524
import io
561525

562526
with io.BytesIO() as engine_bytes:

py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from . import aten_ops_converters, ops_evaluators, prims_ops_converters
2-
from ._conversion import convert_module
2+
from ._conversion import convert_module, interpret_module
33
from ._ConversionContext import ConversionContext
44
from ._ConverterRegistry import * # noqa: F403
55
from ._TRTInterpreter import * # noqa: F403

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,20 @@
77
import torch
88
from torch_tensorrt._Input import Input
99
from torch_tensorrt.dynamo._settings import CompilationSettings
10-
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
10+
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
11+
TRTInterpreter,
12+
TRTInterpreterResult,
13+
)
1114
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1215
from torch_tensorrt.dynamo.utils import get_torch_inputs
1316

1417

15-
def convert_module(
18+
def interpret_module(
1619
module: torch.fx.GraphModule,
1720
inputs: Sequence[Input],
1821
settings: CompilationSettings = CompilationSettings(),
1922
name: str = "",
20-
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
21-
"""Convert an FX module to a TRT module
22-
Args:
23-
module: FX GraphModule to convert
24-
inputs: Sequence of Tensors representing inputs to the module
25-
settings: Compilation settings
26-
name: TRT engine name
27-
Returns:
28-
_PythonTorchTensorRTModule or TorchTensorRTModule
29-
"""
30-
# Specify module output data types to ensure TRT output types agree with
31-
# that of the equivalent Torch module
23+
) -> TRTInterpreterResult:
3224
torch_inputs = get_torch_inputs(inputs, settings.device)
3325
module_outputs = module(*torch_inputs)
3426

Comments (0)