Skip to content

Commit eda7e94

Browse files
committed
fix issues from the comments
1 parent: af8cce9 · commit: eda7e94

File tree

4 files changed

+23
-69
lines changed

4 files changed

+23
-69
lines changed

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
logger = logging.getLogger(__name__)
88

99
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
10-
from ._compiler import compile
10+
from ._compiler import compile, convert_method_to_trt_engine
1111
from ._exporter import export
1212
from ._settings import CompilationSettings
1313
from ._SourceIR import SourceIR
1414
from ._tracer import trace
15-
16-
from torch_tensorrt.dynamo._compiler import * # noqa: F403

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 15 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import logging
55
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

7-
import tensorrt as trt
87
import torch
98
import torch_tensorrt
109
from torch.export import ExportedProgram
@@ -41,7 +40,9 @@
4140
)
4241
from torch_tensorrt.dynamo.conversion import (
4342
CompilationSettings,
43+
UnsupportedOperatorException,
4444
convert_module,
45+
interpret_module,
4546
repair_long_or_double_inputs,
4647
)
4748
from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
@@ -346,51 +347,6 @@ def compile_module(
346347
return partitioned_module
347348

348349

349-
def interpreter(
350-
module: torch.fx.GraphModule,
351-
inputs: Sequence[Input],
352-
settings: CompilationSettings = CompilationSettings(),
353-
name: str = "",
354-
) -> TRTInterpreterResult:
355-
torch_inputs = get_torch_inputs(inputs, settings.device)
356-
module_outputs = module(*torch_inputs)
357-
358-
if not isinstance(module_outputs, (list, tuple)):
359-
module_outputs = [module_outputs]
360-
361-
# Int64 outputs can sometimes be generated from within other operators
362-
# such as aten.sum - such outputs can be truncated
363-
output_dtypes = []
364-
for output in module_outputs:
365-
if settings.truncate_long_and_double and output.dtype == torch.float64:
366-
output_dtypes.append(torch.float32)
367-
elif settings.truncate_long_and_double and output.dtype == torch.int64:
368-
output_dtypes.append(torch.int32)
369-
else:
370-
output_dtypes.append(output.dtype)
371-
372-
interpreter = TRTInterpreter(
373-
module,
374-
inputs,
375-
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
376-
output_dtypes=output_dtypes,
377-
compilation_settings=settings,
378-
)
379-
interpreter_result = interpreter.run(
380-
workspace_size=settings.workspace_size,
381-
precision=settings.precision,
382-
profiling_verbosity=(
383-
trt.ProfilingVerbosity.VERBOSE
384-
if settings.debug
385-
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
386-
),
387-
max_aux_streams=settings.max_aux_streams,
388-
version_compatible=settings.version_compatible,
389-
optimization_level=settings.optimization_level,
390-
)
391-
return interpreter_result
392-
393-
394350
def convert_method_to_trt_engine(
395351
module: torch.fx.GraphModule,
396352
method_name: str = "forward",
@@ -410,6 +366,9 @@ def convert_method_to_trt_engine(
410366
truncate_long_and_double: int = False,
411367
calibrator: object = None,
412368
allow_shape_tensors: bool = False,
369+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
370+
version_compatible: bool = VERSION_COMPATIBLE,
371+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
413372
) -> bytes:
414373
if debug:
415374
set_log_level(logger.parent, logging.DEBUG)
@@ -447,15 +406,20 @@ def convert_method_to_trt_engine(
447406
"device": device,
448407
"workspace_size": workspace_size,
449408
"truncate_long_and_double": truncate_long_and_double,
450-
"max_aux_streams": MAX_AUX_STREAMS,
451-
"version_compatible": VERSION_COMPATIBLE,
452-
"optimization_level": OPTIMIZATION_LEVEL,
409+
"max_aux_streams": max_aux_streams,
410+
"version_compatible": version_compatible,
411+
"optimization_level": optimization_level,
453412
}
454413

455414
settings = CompilationSettings(**compilation_options)
456415
logger.info("Compilation Settings: %s\n", settings)
457-
interpreter_result = interpreter(module, input_list, settings, method_name)
458-
416+
try:
417+
interpreter_result = interpret_module(module, input_list, settings, method_name)
418+
except UnsupportedOperatorException:
419+
logger.error(
420+
f"Conversion of module {module} not currently fully supported or convertible!",
421+
exc_info=True,
422+
)
459423
import io
460424

461425
with io.BytesIO() as engine_bytes:

py/torch_tensorrt/dynamo/conversion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from . import aten_ops_converters, ops_evaluators, prims_ops_converters
2-
from ._conversion import convert_module
2+
from ._conversion import convert_module, interpret_module
33
from ._ConversionContext import ConversionContext
44
from ._ConverterRegistry import * # noqa: F403
55
from ._TRTInterpreter import * # noqa: F403

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,20 @@
77
import torch
88
from torch_tensorrt._Input import Input
99
from torch_tensorrt.dynamo._settings import CompilationSettings
10-
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
10+
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
11+
TRTInterpreter,
12+
TRTInterpreterResult,
13+
)
1114
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
1215
from torch_tensorrt.dynamo.utils import get_torch_inputs
1316

1417

15-
def convert_module(
18+
def interpret_module(
1619
module: torch.fx.GraphModule,
1720
inputs: Sequence[Input],
1821
settings: CompilationSettings = CompilationSettings(),
1922
name: str = "",
20-
) -> PythonTorchTensorRTModule | TorchTensorRTModule:
21-
"""Convert an FX module to a TRT module
22-
Args:
23-
module: FX GraphModule to convert
24-
inputs: Sequence of Tensors representing inputs to the module
25-
settings: Compilation settings
26-
name: TRT engine name
27-
Returns:
28-
_PythonTorchTensorRTModule or TorchTensorRTModule
29-
"""
30-
# Specify module output data types to ensure TRT output types agree with
31-
# that of the equivalent Torch module
23+
) -> TRTInterpreterResult:
3224
torch_inputs = get_torch_inputs(inputs, settings.device)
3325
module_outputs = module(*torch_inputs)
3426

0 commit comments

Comments (0)