fix/feat: Add support for multiple TRT Build Args #2510

Merged (1 commit, Dec 8, 2023)
30 changes: 23 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -16,13 +16,21 @@
from torch_tensorrt.dynamo._defaults import (
DEBUG,
DEVICE,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -51,17 +59,18 @@ def compile(
inputs: Tuple[Any, ...],
*,
device: Optional[Union[Device, torch.device, str]] = DEVICE,
disable_tf32: bool = False,
sparse_weights: bool = False,
disable_tf32: bool = DISABLE_TF32,
sparse_weights: bool = SPARSE_WEIGHTS,
enabled_precisions: Set[torch.dtype] | Tuple[torch.dtype] = (torch.float32,),
refit: bool = False,
engine_capability: EngineCapability = ENGINE_CAPABILITY,
refit: bool = REFIT,
debug: bool = DEBUG,
capability: EngineCapability = EngineCapability.default,
num_avg_timing_iters: int = 1,
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
workspace_size: int = WORKSPACE_SIZE,
dla_sram_size: int = 1048576,
dla_local_dram_size: int = 1073741824,
dla_global_dram_size: int = 536870912,
dla_sram_size: int = DLA_SRAM_SIZE,
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
calibrator: object = None,
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
@@ -199,6 +208,13 @@ def compile(
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"refit": refit,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
"dla_global_dram_size": dla_global_dram_size,
}

settings = CompilationSettings(**compilation_options)
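With the new options threaded through compilation_options, all of these build args can be supplied at the top-level compile entry point. A minimal sketch, assuming the usual torch_tensorrt.compile front end forwards keyword args to the dynamo path; MyModel and the input shape are placeholders:

import torch
import torch_tensorrt
from tensorrt import EngineCapability

model = MyModel().eval().cuda()  # hypothetical user module
inputs = [torch.randn(1, 3, 224, 224).cuda()]

trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    disable_tf32=True,            # keep matmul/conv accumulation in strict FP32
    sparse_weights=True,          # let the builder exploit structured weight sparsity
    refit=True,                   # build a refittable engine
    engine_capability=EngineCapability.STANDARD,
    num_avg_timing_iters=2,       # average kernel timings over two iterations
)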
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -1,19 +1,28 @@
import torch
from tensorrt import EngineCapability
from torch_tensorrt._Device import Device

PRECISION = torch.float32
DEBUG = False
DEVICE = None
DISABLE_TF32 = False
DLA_LOCAL_DRAM_SIZE = 1073741824
DLA_GLOBAL_DRAM_SIZE = 536870912
DLA_SRAM_SIZE = 1048576
ENGINE_CAPABILITY = EngineCapability.STANDARD
WORKSPACE_SIZE = 0
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
MAX_AUX_STREAMS = None
NUM_AVG_TIMING_ITERS = 1
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
SPARSE_WEIGHTS = False
TRUNCATE_LONG_AND_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REFIT = False
REQUIRE_FULL_COMPILATION = False


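The DLA pool defaults are raw byte counts; for readability, the sketch below simply restates them in binary units (illustrative only, values copied from above):

from torch_tensorrt.dynamo._defaults import (
    DLA_GLOBAL_DRAM_SIZE,
    DLA_LOCAL_DRAM_SIZE,
    DLA_SRAM_SIZE,
)

assert DLA_SRAM_SIZE == 1 << 20           # 1 MiB of fast software-managed SRAM
assert DLA_LOCAL_DRAM_SIZE == 1 << 30     # 1 GiB of host RAM for cross-op tensors
assert DLA_GLOBAL_DRAM_SIZE == 512 << 20  # 512 MiB of host RAM for weights/metadata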
25 changes: 25 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -2,16 +2,25 @@
from typing import Optional, Set

import torch
from tensorrt import EngineCapability
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo._defaults import (
DEBUG,
DISABLE_TF32,
DLA_GLOBAL_DRAM_SIZE,
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENGINE_CAPABILITY,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
OPTIMIZATION_LEVEL,
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
REFIT,
REQUIRE_FULL_COMPILATION,
SPARSE_WEIGHTS,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
@@ -46,6 +55,14 @@ class CompilationSettings:
device (Device): GPU to compile the model on
require_full_compilation (bool): Whether to require that the graph be fully compiled in TensorRT.
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
refit (bool): Whether to build a refittable engine
engine_capability (trt.EngineCapability): Restrict kernel selection to safe GPU kernels or safe DLA kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
dla_sram_size (int): Fast software-managed RAM used by DLA to communicate within a layer
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
"""

precision: torch.dtype = PRECISION
@@ -63,3 +80,11 @@ class CompilationSettings:
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
disable_tf32: bool = DISABLE_TF32
sparse_weights: bool = SPARSE_WEIGHTS
refit: bool = REFIT
engine_capability: EngineCapability = ENGINE_CAPABILITY
num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS
dla_sram_size: int = DLA_SRAM_SIZE
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
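
Because every field carries a default, the new knobs can be set selectively when constructing CompilationSettings directly against the lower-level dynamo APIs. A sketch with illustrative values; unspecified fields fall back to the _defaults constants imported above:

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(
    disable_tf32=True,
    refit=True,
    num_avg_timing_iters=3,
    dla_local_dram_size=1 << 30,  # bytes
)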
79 changes: 48 additions & 31 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -4,6 +4,7 @@
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -23,8 +24,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter

# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -96,6 +95,7 @@ def __init__(
self._itensor_to_tensor_meta: Dict[
trt.tensorrt.ITensor, TensorMetadata
] = dict()
self.compilation_settings = compilation_settings

# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes
@@ -118,40 +118,25 @@ def validate_conversion(self) -> Set[str]:

def run(
self,
workspace_size: int = 0,
precision: torch.dtype = torch.float32, # TODO: @peri044 Needs to be expanded to set
sparse_weights: bool = False,
disable_tf32: bool = False,
force_fp32_output: bool = False,
strict_type_constraints: bool = False,
algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
timing_cache: Optional[trt.ITimingCache] = None,
profiling_verbosity: Optional[trt.ProfilingVerbosity] = None,
tactic_sources: Optional[int] = None,
max_aux_streams: Optional[int] = None,
version_compatible: bool = False,
optimization_level: Optional[int] = None,
) -> TRTInterpreterResult:
"""
Build TensorRT engine with some configs.
Args:
workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
precision: the precision model layers are running on (TensorRT will choose the best performance precision).
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
force_fp32_output: force output to be fp32
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layers for numeric reasons.
algorithm_selector: set up algorithm selection for certain layers
timing_cache: enable timing cache for TensorRT
profiling_verbosity: TensorRT logging level
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
version_compatible: Provide version forward-compatibility for engine plan files
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
Return:
TRTInterpreterResult
"""
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

precision = self.compilation_settings.precision
# For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
# force_fp32_output=False. Overridden by specifying output_dtypes
self.output_fp16 = not force_fp32_output and precision == torch.float16
@@ -172,9 +157,9 @@ def run(

builder_config = self.builder.create_builder_config()

if workspace_size != 0:
if self.compilation_settings.workspace_size != 0:
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.WORKSPACE, workspace_size
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
)

cache = None
@@ -187,34 +172,66 @@

if version.parse(trt.__version__) >= version.parse("8.2"):
builder_config.profiling_verbosity = (
profiling_verbosity
if profiling_verbosity
trt.ProfilingVerbosity.VERBOSE
if self.compilation_settings.debug
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
)

if version.parse(trt.__version__) >= version.parse("8.6"):
if max_aux_streams is not None:
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
builder_config.max_aux_streams = max_aux_streams
if version_compatible:
if self.compilation_settings.max_aux_streams is not None:
_LOGGER.info(
f"Setting max aux streams to {self.compilation_settings.max_aux_streams}"
)
builder_config.max_aux_streams = (
self.compilation_settings.max_aux_streams
)
if self.compilation_settings.version_compatible:
_LOGGER.info("Using version compatible")
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
if optimization_level is not None:
_LOGGER.info(f"Using optimization level {optimization_level}")
builder_config.builder_optimization_level = optimization_level
if self.compilation_settings.optimization_level is not None:
_LOGGER.info(
f"Using optimization level {self.compilation_settings.optimization_level}"
)
builder_config.builder_optimization_level = (
self.compilation_settings.optimization_level
)

builder_config.engine_capability = self.compilation_settings.engine_capability
builder_config.avg_timing_iterations = (
self.compilation_settings.num_avg_timing_iters
)

if self.compilation_settings.device.device_type == trt.DeviceType.DLA:
builder_config.DLA_core = self.compilation_settings.device.dla_core
_LOGGER.info(f"Using DLA core {self.compilation_settings.device.dla_core}")
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.DLA_MANAGED_SRAM,
self.compilation_settings.dla_sram_size,
)
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.DLA_LOCAL_DRAM,
self.compilation_settings.dla_local_dram_size,
)
builder_config.set_memory_pool_limit(
trt.MemoryPoolType.DLA_GLOBAL_DRAM,
self.compilation_settings.dla_global_dram_size,
)

if precision == torch.float16:
builder_config.set_flag(trt.BuilderFlag.FP16)

if precision == torch.int8:
builder_config.set_flag(trt.BuilderFlag.INT8)

if sparse_weights:
if self.compilation_settings.sparse_weights:
builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)

if disable_tf32:
if self.compilation_settings.disable_tf32:
builder_config.clear_flag(trt.BuilderFlag.TF32)

if self.compilation_settings.refit:
builder_config.set_flag(trt.BuilderFlag.REFIT)

if strict_type_constraints:
builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)

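The REFIT flag set above is what enables post-build weight updates. A sketch of the downstream refit workflow using raw TensorRT APIs — the deserialization step, serialized_engine, and the weight name "conv1.weight" are assumptions for illustration, not part of this PR:

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
# serialized_engine: bytes of an engine built with trt.BuilderFlag.REFIT set
engine = runtime.deserialize_cuda_engine(serialized_engine)

refitter = trt.Refitter(engine, logger)
# Shapes must match the weights baked into the engine at build time.
new_w = np.zeros((64, 3, 7, 7), dtype=np.float32)
refitter.set_named_weights("conv1.weight", trt.Weights(new_w))
assert refitter.refit_cuda_engine()  # returns True on success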
16 changes: 2 additions & 14 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -3,15 +3,14 @@
import io
from typing import Sequence

import tensorrt as trt
import torch
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_torch_inputs

import tensorrt as trt


def convert_module(
module: torch.fx.GraphModule,
@@ -54,18 +53,7 @@ def convert_module(
output_dtypes=output_dtypes,
compilation_settings=settings,
)
interpreter_result = interpreter.run(
workspace_size=settings.workspace_size,
precision=settings.precision,
profiling_verbosity=(
trt.ProfilingVerbosity.VERBOSE
if settings.debug
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
),
max_aux_streams=settings.max_aux_streams,
version_compatible=settings.version_compatible,
optimization_level=settings.optimization_level,
)
interpreter_result = interpreter.run()

if settings.use_python_runtime:
return PythonTorchTensorRTModule(
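After this change the interpreter is configured once at construction time, and run() takes no per-call build arguments. A sketch of the new calling convention; gm, inputs, and output_dtypes are placeholders:

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter

settings = CompilationSettings(version_compatible=True, optimization_level=5)
interpreter = TRTInterpreter(
    gm,                     # torch.fx.GraphModule (placeholder)
    inputs,                 # sequence of torch_tensorrt Input specs (placeholder)
    output_dtypes=output_dtypes,
    compilation_settings=settings,
)
interpreter_result = interpreter.run()  # all build config is read from settings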
8 changes: 4 additions & 4 deletions tests/py/dynamo/conversion/harness.py
@@ -50,7 +50,6 @@ def run_test(
interpreter,
rtol,
atol,
precision=torch.float,
check_dtype=True,
):
with torch.no_grad():
@@ -60,7 +59,7 @@

mod.eval()
start = time.perf_counter()
interpreter_result = interpreter.run(precision=precision)
interpreter_result = interpreter.run()
sec = time.perf_counter() - start
_LOGGER.info(f"Interpreter run time(s): {sec}")
trt_mod = PythonTorchTensorRTModule(
@@ -234,7 +233,9 @@ def run_test(

# Previous instance of the interpreter auto-casted 64-bit inputs
# We replicate this behavior here
compilation_settings = CompilationSettings(truncate_long_and_double=True)
compilation_settings = CompilationSettings(
precision=precision, truncate_long_and_double=True
)

interp = TRTInterpreter(
mod,
@@ -248,7 +249,6 @@
interp,
rtol,
atol,
precision,
check_dtype,
)
