Skip to content

Commit 5a47b1e

Browse files
committed
fix issues from the comments
1 parent d34966e commit 5a47b1e

File tree

4 files changed

+45
-75
lines changed

4 files changed

+45
-75
lines changed

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,8 @@
77
logger = logging.getLogger(__name__)
88

99
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
10-
from ._compiler import compile
10+
from ._compiler import compile, convert_method_to_trt_engine
1111
from ._exporter import export
1212
from ._settings import CompilationSettings
1313
from ._SourceIR import SourceIR
1414
from ._tracer import trace
15-
16-
from torch_tensorrt.dynamo._compiler import * # noqa: F403

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 15 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import logging
55
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

7-
import tensorrt as trt
87
import torch
98
import torch_tensorrt
109
from torch.export import ExportedProgram
@@ -33,13 +32,11 @@
3332
)
3433
from torch_tensorrt.dynamo.conversion import (
3534
CompilationSettings,
35+
UnsupportedOperatorException,
3636
convert_module,
37+
interpret_module,
3738
repair_long_or_double_inputs,
3839
)
39-
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
40-
TRTInterpreter,
41-
TRTInterpreterResult,
42-
)
4340
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
4441
from torch_tensorrt.dynamo.utils import (
4542
get_torch_inputs,
@@ -327,51 +324,6 @@ def compile_module(
327324
return partitioned_module
328325

329326

330-
def interpreter(
331-
module: torch.fx.GraphModule,
332-
inputs: Sequence[Input],
333-
settings: CompilationSettings = CompilationSettings(),
334-
name: str = "",
335-
) -> TRTInterpreterResult:
336-
torch_inputs = get_torch_inputs(inputs, settings.device)
337-
module_outputs = module(*torch_inputs)
338-
339-
if not isinstance(module_outputs, (list, tuple)):
340-
module_outputs = [module_outputs]
341-
342-
# Int64 outputs can sometimes be generated from within other operators
343-
# such as aten.sum - such outputs can be truncated
344-
output_dtypes = []
345-
for output in module_outputs:
346-
if settings.truncate_long_and_double and output.dtype == torch.float64:
347-
output_dtypes.append(torch.float32)
348-
elif settings.truncate_long_and_double and output.dtype == torch.int64:
349-
output_dtypes.append(torch.int32)
350-
else:
351-
output_dtypes.append(output.dtype)
352-
353-
interpreter = TRTInterpreter(
354-
module,
355-
inputs,
356-
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
357-
output_dtypes=output_dtypes,
358-
compilation_settings=settings,
359-
)
360-
interpreter_result = interpreter.run(
361-
workspace_size=settings.workspace_size,
362-
precision=settings.precision,
363-
profiling_verbosity=(
364-
trt.ProfilingVerbosity.VERBOSE
365-
if settings.debug
366-
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
367-
),
368-
max_aux_streams=settings.max_aux_streams,
369-
version_compatible=settings.version_compatible,
370-
optimization_level=settings.optimization_level,
371-
)
372-
return interpreter_result
373-
374-
375327
def convert_method_to_trt_engine(
376328
module: torch.fx.GraphModule,
377329
method_name: str = "forward",
@@ -391,6 +343,9 @@ def convert_method_to_trt_engine(
391343
truncate_long_and_double: int = False,
392344
calibrator: object = None,
393345
allow_shape_tensors: bool = False,
346+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
347+
version_compatible: bool = VERSION_COMPATIBLE,
348+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
394349
) -> bytes:
395350
if debug:
396351
set_log_level(logger.parent, logging.DEBUG)
@@ -428,15 +383,20 @@ def convert_method_to_trt_engine(
428383
"device": device,
429384
"workspace_size": workspace_size,
430385
"truncate_long_and_double": truncate_long_and_double,
431-
"max_aux_streams": MAX_AUX_STREAMS,
432-
"version_compatible": VERSION_COMPATIBLE,
433-
"optimization_level": OPTIMIZATION_LEVEL,
386+
"max_aux_streams": max_aux_streams,
387+
"version_compatible": version_compatible,
388+
"optimization_level": optimization_level,
434389
}
435390

436391
settings = CompilationSettings(**compilation_options)
437392
logger.info("Compilation Settings: %s\n", settings)
438-
interpreter_result = interpreter(module, input_list, settings, method_name)
439-
393+
try:
394+
interpreter_result = interpret_module(module, input_list, settings, method_name)
395+
except UnsupportedOperatorException:
396+
logger.error(
397+
f"Conversion of module {module} not currently fully supported or convertible!",
398+
exc_info=True,
399+
)
440400
import io
441401

442402
with io.BytesIO() as engine_bytes:

py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
from . import aten_ops_converters, ops_evaluators, prims_ops_converters
2-
from ._conversion import convert_module
2+
from ._conversion import convert_module, interpret_module
33
from ._ConversionContext import ConversionContext
44
from ._ConverterRegistry import * # noqa: F403
55
from ._TRTInterpreter import * # noqa: F403

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 28 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -3,33 +3,24 @@
33
import io
44
from typing import Sequence
55

6+
import tensorrt as trt
67
import torch
78
from torch_tensorrt._Input import Input
89
from torch_tensorrt.dynamo._settings import CompilationSettings
9-
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
10+
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
11+
TRTInterpreter,
12+
TRTInterpreterResult,
13+
)
1014
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1115
from torch_tensorrt.dynamo.utils import get_torch_inputs
1216

13-
import tensorrt as trt
14-
1517

16-
def convert_module(
18+
def interpret_module(
1719
module: torch.fx.GraphModule,
1820
inputs: Sequence[Input],
1921
settings: CompilationSettings = CompilationSettings(),
2022
name: str = "",
21-
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
22-
"""Convert an FX module to a TRT module
23-
Args:
24-
module: FX GraphModule to convert
25-
inputs: Sequence of Tensors representing inputs to the module
26-
settings: Compilation settings
27-
name: TRT engine name
28-
Returns:
29-
_PythonTorchTensorRTModule or TorchTensorRTModule
30-
"""
31-
# Specify module output data types to ensure TRT output types agree with
32-
# that of the equivalent Torch module
23+
) -> TRTInterpreterResult:
3324
torch_inputs = get_torch_inputs(inputs, settings.device)
3425
module_outputs = module(*torch_inputs)
3526

@@ -66,6 +57,27 @@ def convert_module(
6657
version_compatible=settings.version_compatible,
6758
optimization_level=settings.optimization_level,
6859
)
60+
return interpreter_result
61+
62+
63+
def convert_module(
64+
module: torch.fx.GraphModule,
65+
inputs: Sequence[Input],
66+
settings: CompilationSettings = CompilationSettings(),
67+
name: str = "",
68+
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
69+
"""Convert an FX module to a TRT module
70+
Args:
71+
module: FX GraphModule to convert
72+
inputs: Sequence of Tensors representing inputs to the module
73+
settings: Compilation settings
74+
name: TRT engine name
75+
Returns:
76+
_PythonTorchTensorRTModule or TorchTensorRTModule
77+
"""
78+
# Specify module output data types to ensure TRT output types agree with
79+
# that of the equivalent Torch module
80+
interpreter_result = interpret_module(module, inputs, settings, name)
6981

7082
if settings.use_python_runtime:
7183
return PythonTorchTensorRTModule(

0 commit comments

Comments
 (0)