
Commit 9723376

fix issues from comments, add more unit tests
1 parent 24f5c2c commit 9723376

File tree: 7 files changed, +223 -41 lines

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 16 additions & 8 deletions
@@ -86,8 +86,8 @@ def compile(
     lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
     cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
     reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
-    engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
-    engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
+    engine_cache_dir: Optional[str] = _defaults.ENGINE_CACHE_DIR,
+    engine_cache_size: Optional[int] = _defaults.ENGINE_CACHE_SIZE,
     custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
@@ -156,8 +156,8 @@ def compile(
         lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
-        engine_cache_dir (str): Directory to store the cached TRT engines
-        engine_cache_size (int): Maximum hard-disk space to use for the engine cache
+        engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
+        engine_cache_size (Optional[int]): Maximum hard-disk space (bytes) to use for the engine cache, default is 1GB. If the cache exceeds this size, the oldest engines will be removed by default
         custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
         **kwargs: Any,
     Returns:
@@ -235,12 +235,16 @@ def compile(
     gm = post_lowering(gm)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
+    engine_cache = None
     if cache_built_engines or reuse_cached_engines:
         assert (
             make_refitable
         ), "Engine caching requires make_refitable to be set to True"
-        if custom_engine_cache is None:
-            custom_engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)
+        engine_cache = (
+            custom_engine_cache
+            if custom_engine_cache is not None
+            else DiskEngineCache(engine_cache_dir, engine_cache_size)
+        )
 
     compilation_options = {
         "enabled_precisions": (
@@ -277,12 +281,13 @@ def compile(
         "lazy_engine_init": lazy_engine_init,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
-        "custom_engine_cache": custom_engine_cache,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-    trt_gm = compile_module(gm, trt_arg_inputs, trt_kwarg_inputs, settings)
+    trt_gm = compile_module(
+        gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
+    )
     return trt_gm
 
 
@@ -291,6 +296,7 @@ def compile_module(
     sample_arg_inputs: Sequence[Input],
     sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
     settings: CompilationSettings = CompilationSettings(),
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> torch.fx.GraphModule:
     """Compile a traced FX module
 
@@ -301,6 +307,7 @@ def compile_module(
         arg_inputs: Inputs to the module
         kwarg_inputs: kwargs to the module
         settings: Compilation settings
+        engine_cache: Engine cache instance to store/load compiled engines
     Returns:
         Compiled FX GraphModule
     """
@@ -480,6 +487,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             submodule_inputs,
             settings=settings,
             name=name,
+            engine_cache=engine_cache,
         )
 
         trt_modules[name] = trt_module
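
With the cache object now created inside compile() and threaded through compile_module(), the user-facing knobs stay the same. Below is a minimal usage sketch against the signature above; the toy model, input shape, and cache directory are illustrative only, not taken from this commit.

import torch
import torch_tensorrt

# Illustrative toy model and input; any refittable model is handled the same way.
model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 32).cuda()]

trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    make_refitable=True,        # engine caching asserts this flag is set
    cache_built_engines=True,   # save freshly built engines
    reuse_cached_engines=True,  # load engines whose graph hash matches
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # illustrative path
    engine_cache_size=1 << 30,  # 1 GB cap; oldest entries evicted first
)

A second, identical compile() call should then skip the TensorRT build and deserialize the engine from the DiskEngineCache created above.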

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 4 deletions
@@ -7,7 +7,6 @@
 from torch_tensorrt.dynamo._defaults import (
     ASSUME_DYNAMIC_SHAPE_SUPPORT,
     CACHE_BUILT_ENGINES,
-    CUSTOM_ENGINE_CACHE,
     DEBUG,
     DISABLE_TF32,
     DLA_GLOBAL_DRAM_SIZE,
@@ -36,7 +35,6 @@
     WORKSPACE_SIZE,
     default_device,
 )
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 
 
 @dataclass
@@ -80,7 +78,6 @@ class CompilationSettings:
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
         cache_built_engines (bool): Whether to save the compiled TRT engines to storage
         reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
-        custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -115,4 +112,3 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
-    custom_engine_cache: Optional[BaseEngineCache] = CUSTOM_ENGINE_CACHE

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 5 additions & 2 deletions
@@ -48,21 +48,23 @@ def torch_tensorrt_backend(
 def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
-    settings = parse_dynamo_kwargs(kwargs)
-    return _pretraced_backend(gm, sample_inputs, settings)
+    settings, engine_cache = parse_dynamo_kwargs(kwargs)
+    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
 
 
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
+    engine_cache: Any = None,
 ) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines
 
     Args:
         module: FX GraphModule to convert
         inputs: Inputs to the module
         settings: Compilation settings
+        engine_cache: Engine cache instance
     Returns:
         Compiled FX GraphModule
     """
@@ -109,6 +111,7 @@ def _pretraced_backend(
                 gm,
                 torchtrt_inputs,
                 settings=settings,
+                engine_cache=engine_cache,
             )
             return trt_compiled
         except (AssertionError, RuntimeError):
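
The same options reach this backend through torch.compile, since parse_dynamo_kwargs now hands both the settings and the cache back to _pretraced_backend. A sketch of that path follows; it assumes the "torch_tensorrt" backend name registered by the library, and the model and input are illustrative.

import torch
import torch_tensorrt  # noqa: F401  (importing registers the dynamo backend)

model = torch.nn.Linear(16, 16).cuda().eval()  # illustrative model
example = torch.randn(4, 16).cuda()

# The options dict is forwarded as kwargs into parse_dynamo_kwargs,
# which now returns (settings, engine_cache) as shown in the diff above.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "make_refitable": True,
        "cache_built_engines": True,
        "reuse_cached_engines": True,
        "engine_cache_dir": "/tmp/torch_trt_engine_cache",  # illustrative path
    },
)
out = compiled(example)  # engines are built (and cached) on the first run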

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 15 additions & 10 deletions
@@ -27,6 +27,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -71,6 +72,7 @@ def __init__(
         logger_level: trt.ILogger.Severity = trt.ILogger.Severity.WARNING,
         output_dtypes: Optional[Sequence[dtype]] = None,
         compilation_settings: CompilationSettings = CompilationSettings(),
+        engine_cache: Optional[BaseEngineCache] = None,
     ):
         super().__init__(module)
 
@@ -126,6 +128,9 @@ def __init__(
         self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
         self.weight_name_map: Optional[dict[str, Any]] = None
 
+        # Engine cache for storing and reusing TRT engines
+        self.engine_cache = engine_cache
+
     def validate_conversion(self) -> Set[str]:
         missing_converters: Set[str] = set()
 
@@ -521,22 +526,22 @@ def run(
         Return:
             TRTInterpreterResult
         """
-        if (
-            self.compilation_settings.custom_engine_cache is not None
-        ):  # custom_engine_cache could be None if this function is called from convert_exported_program_to_serialized_trt_engine etc.
+        # self.engine_cache could be None if:
+        # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
+        # 2) both cache_built_engines and reuse_cached_engines are False
+        if self.engine_cache is not None:
             if (
                 self.compilation_settings.cache_built_engines
                 or self.compilation_settings.reuse_cached_engines
             ):
-                engine_cache = self.compilation_settings.custom_engine_cache
-                hash_val = engine_cache.get_hash(self.module)
+                hash_val = self.engine_cache.get_hash(self.module)
 
                 if self.compilation_settings.reuse_cached_engines:
                     # query the cached TRT engine
-                    blob = engine_cache.load(hash_val)
+                    blob = self.engine_cache.load(hash_val)
                     if blob is not None:  # hit the cache
                         serialized_engine, input_names, output_names, weight_name_map = (
-                            engine_cache.unpack(blob)
+                            self.engine_cache.unpack(blob)
                         )
                         self._input_names = input_names
                         self._output_names = output_names
@@ -605,16 +610,16 @@ def run(
             builder_config, self.compilation_settings.timing_cache_path
         )
         if (
-            self.compilation_settings.custom_engine_cache is not None
+            self.engine_cache is not None
             and self.compilation_settings.cache_built_engines
         ):
-            blob = engine_cache.pack(
+            blob = self.engine_cache.pack(
                 serialized_engine,
                 self._input_names,
                 self._output_names,
                 self.weight_name_map,
            )
-            engine_cache.save(hash_val, blob)
+            self.engine_cache.save(hash_val, blob)
 
         with io.BytesIO() as engine_bytes:
             engine_bytes.write(serialized_engine)
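
run() only touches the cache through get_hash, pack, unpack, load, and save, so the storage backend is fully pluggable. Below is a minimal in-memory sketch of a custom cache; it assumes, based on this usage, that save and load are the methods a BaseEngineCache subclass supplies while get_hash/pack/unpack come from the base class, so treat the exact override contract (and this class) as an illustration rather than documented API.

from typing import Dict, Optional

from torch_tensorrt.dynamo._engine_caching import BaseEngineCache


class InMemoryEngineCache(BaseEngineCache):  # hypothetical example subclass
    """Stores packed engine blobs in a dict instead of on disk."""

    def __init__(self) -> None:
        self._blobs: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes) -> None:
        # Called after pack() when cache_built_engines is True.
        self._blobs[hash] = blob

    def load(self, hash: str) -> Optional[bytes]:
        # Called when reuse_cached_engines is True; returning None signals a miss.
        return self._blobs.get(hash)

An instance of such a class would be passed as custom_engine_cache to compile(), which after this refactor forwards it as the engine_cache argument instead of storing it on CompilationSettings.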

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 9 additions & 1 deletion
@@ -10,6 +10,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
     TRTInterpreter,
@@ -70,6 +71,7 @@ def interpret_module_to_result(
     settings: CompilationSettings = CompilationSettings(),
     arg_inputs: Optional[Sequence[Input]] = None,
     kwarg_inputs: Optional[dict[str, Any]] = None,
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> TRTInterpreterResult:
     """Interpret an FX module to a TRTInterpreterResult
     Args:
@@ -79,6 +81,7 @@ def interpret_module_to_result(
         arg_inputs: Sequence of Tensors representing inputs to the module.
         kwarg_inputs: A dictionary of Tensors representing inputs to the module.
         settings: Compilation settings
+        engine_cache: Engine cache instance
     Returns:
         TRTInterpreterResult
     """
@@ -105,6 +108,7 @@ def interpret_module_to_result(
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
         output_dtypes=output_dtypes,
         compilation_settings=settings,
+        engine_cache=engine_cache,
     )
     interpreter_result = interpreter.run()
     return interpreter_result
@@ -115,17 +119,21 @@ def convert_module(
     inputs: Sequence[Input],
     settings: CompilationSettings = CompilationSettings(),
     name: str = "",
+    engine_cache: Optional[BaseEngineCache] = None,
 ) -> PythonTorchTensorRTModule | TorchTensorRTModule:
     """Convert an FX module to a TRT module
     Args:
         module: FX GraphModule to convert
         inputs: Sequence of Tensors representing inputs to the module
         settings: Compilation settings
         name: TRT engine name
+        engine_cache: Engine cache instance
     Returns:
         PythonTorchTensorRTModule or TorchTensorRTModule
     """
-    interpreter_result = interpret_module_to_result(module, inputs, settings)
+    interpreter_result = interpret_module_to_result(
+        module, inputs, settings, engine_cache=engine_cache
+    )
 
     rt_cls = PythonTorchTensorRTModule
 

py/torch_tensorrt/dynamo/utils.py

Lines changed: 12 additions & 7 deletions
@@ -3,7 +3,7 @@
 import logging
 from dataclasses import fields, replace
 from enum import Enum
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import tensorrt as trt
@@ -12,6 +12,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
 from packaging import version
@@ -301,7 +302,9 @@ def to_torch_tensorrt_device(
     return Device._from(device)
 
 
-def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
+def parse_dynamo_kwargs(
+    kwargs: Any,
+) -> Tuple[CompilationSettings, Optional[BaseEngineCache]]:
     """Parses the kwargs field of a Dynamo backend
 
     Args:
@@ -360,11 +363,15 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
 
     # If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided,
     # then create a default disk engine cache
+    engine_cache = None
     if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
         assert kwargs.get(
             "make_refitable"
         ), "Engine caching requires make_refitable to be set to True"
-        if settings.custom_engine_cache is None:
+
+        if kwargs.get("custom_engine_cache") is not None:
+            engine_cache = kwargs.get("custom_engine_cache")
+        else:
             from torch_tensorrt.dynamo._engine_caching import DiskEngineCache
 
             engine_cache_dir = kwargs.get(
@@ -373,13 +380,11 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
             engine_cache_size = kwargs.get(
                 "engine_cache_size", _defaults.ENGINE_CACHE_SIZE
             )
-            settings.custom_engine_cache = DiskEngineCache(
-                engine_cache_dir, engine_cache_size
-            )
+            engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)
 
     logger.info("Compilation Settings: %s\n", settings)
 
-    return settings
+    return settings, engine_cache
 
 
 def req_torch_version(min_torch_version: str = "2.dev") -> Callable[..., Any]:
