Skip to content

Commit 37d3311

Browse files
committed
support customizing engine cache class
1 parent 65e0171 commit 37d3311

File tree

5 files changed

+49
-18
lines changed

5 files changed

+49
-18
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import collections.abc
44
import logging
55
import warnings
6-
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
6+
from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Type, Union
77

88
import torch
99
from torch.export import ExportedProgram
@@ -18,6 +18,7 @@
1818
dryrun_stats_display,
1919
parse_non_trt_nodes,
2020
)
21+
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, EngineCache
2122
from torch_tensorrt.dynamo.conversion import (
2223
CompilationSettings,
2324
UnsupportedOperatorException,
@@ -83,6 +84,7 @@ def compile(
8384
load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
8485
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
8586
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
87+
engine_cache_class: Type[BaseEngineCache] = EngineCache,
8688
**kwargs: Any,
8789
) -> torch.fx.GraphModule:
8890
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -147,6 +149,7 @@ def compile(
147149
load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
148150
engine_cache_dir (str): Directory to store the cached TRT engines
149151
engine_cache_size (int): Maximum hard-disk space to use for the engine cache
152+
engine_cache_class (Type[BaseEngineCache]): Engine cache class to use for saving and loading engines. Users can provide their own engine cache class by inheriting from BaseEngineCache
150153
**kwargs: Any,
151154
Returns:
152155
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -246,6 +249,7 @@ def compile(
246249
"load_engine_cache": load_engine_cache,
247250
"engine_cache_dir": engine_cache_dir,
248251
"engine_cache_size": engine_cache_size,
252+
"engine_cache_class": engine_cache_class,
249253
}
250254

251255
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import torch
55
from torch_tensorrt._Device import Device
66
from torch_tensorrt._enums import EngineCapability, dtype
7+
from torch_tensorrt.dynamo._engine_caching import EngineCache
78

89
ENABLED_PRECISIONS = {dtype.f32}
910
DEBUG = False
@@ -36,6 +37,7 @@
3637
LOAD_ENGINE_CACHE = True
3738
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
3839
ENGINE_CACHE_SIZE = 1 << 30
40+
ENGINE_CACHE_CLASS = EngineCache
3941

4042

4143
def default_device() -> Device:

py/torch_tensorrt/dynamo/_engine_caching.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33
import logging
44
import os
55
from abc import ABC, abstractmethod
6-
from typing import List, Optional, Tuple, cast
6+
from typing import Any, List, Optional, Tuple, cast
77

88
import torch
99
from torch._inductor.codecache import FxGraphCachePickler
1010
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
11-
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR, ENGINE_CACHE_SIZE
1211

1312
_LOGGER: logging.Logger = logging.getLogger(__name__)
1413

1514

1615
class BaseEngineCache(ABC):
1716

17+
@abstractmethod
18+
def __init__(
19+
self,
20+
*args: Any,
21+
**kwargs: Any,
22+
) -> None:
23+
pass
24+
1825
@staticmethod
1926
def get_hash(gm: torch.fx.GraphModule) -> str:
2027
"""Get the hash value of the GraphModule
@@ -42,7 +49,7 @@ def save(
4249
serialized_engine: bytes,
4350
input_names: List[str],
4451
output_names: List[str],
45-
) -> None:
52+
) -> bool:
4653
"""Save the serialized engine to hard disk
4754
4855
Args:
@@ -52,7 +59,7 @@ def save(
5259
output_names (List[str]): output names of TRT engine
5360
5461
Returns:
55-
None
62+
bool: whether the serialized engine is saved successfully
5663
"""
5764
pass
5865

@@ -82,8 +89,8 @@ class EngineCache(BaseEngineCache):
8289

8390
def __init__(
8491
self,
85-
engine_cache_size: int = ENGINE_CACHE_SIZE,
86-
engine_cache_dir: str = ENGINE_CACHE_DIR,
92+
engine_cache_size: int,
93+
engine_cache_dir: str,
8794
) -> None:
8895
self.total_engine_cache_size = engine_cache_size
8996
self.available_engine_cache_size = engine_cache_size
@@ -98,7 +105,7 @@ def has_available_cache_size(self, serialized_engine: bytes) -> bool:
98105
Returns:
99106
bool: whether the cache has available size for the serialized engine
100107
"""
101-
return serialized_engine.nbytes <= self.available_engine_cache_size
108+
return int(serialized_engine.nbytes) <= self.available_engine_cache_size
102109

103110
def clear_cache(self, size: int) -> None:
104111

@@ -113,13 +120,13 @@ def save(
113120
serialized_engine: bytes,
114121
input_names: List[str],
115122
output_names: List[str],
116-
) -> None:
117-
serialized_engine_size = serialized_engine.nbytes
123+
) -> bool:
124+
serialized_engine_size = int(serialized_engine.nbytes)
118125
if serialized_engine_size > self.total_engine_cache_size:
119126
_LOGGER.warning(
120127
f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}."
121128
)
122-
return
129+
return False
123130

124131
if not self.has_available_cache_size(serialized_engine):
125132
self.clear_cache(serialized_engine_size)
@@ -129,10 +136,23 @@ def save(
129136
self.engine_cache_dir,
130137
f"{hash}/engine--{input_names}--{output_names}.trt",
131138
)
132-
os.makedirs(os.path.dirname(path), exist_ok=True)
133-
with open(path, "wb") as f:
134-
f.write(serialized_engine)
139+
try:
140+
os.makedirs(os.path.dirname(path), exist_ok=True)
141+
with open(path, "wb") as f:
142+
f.write(serialized_engine)
143+
except Exception as e:
144+
_LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
145+
return False
146+
135147
_LOGGER.info(f"A TRT engine was cached to {path}")
148+
self.available_engine_cache_size -= serialized_engine_size
149+
return True
150+
151+
else:
152+
_LOGGER.warning(
153+
f"The size of the serialized engine {serialized_engine_size} is still larger than the available cache size {self.available_engine_cache_size}."
154+
)
155+
return False
136156

137157
def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
138158
directory = os.path.join(self.engine_cache_dir, hash)

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Collection, Optional, Set, Union
2+
from typing import Collection, Optional, Set, Type, Union
33

44
from torch.fx.node import Target
55
from torch_tensorrt._Device import Device
@@ -14,6 +14,7 @@
1414
DRYRUN,
1515
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1616
ENABLED_PRECISIONS,
17+
ENGINE_CACHE_CLASS,
1718
ENGINE_CACHE_DIR,
1819
ENGINE_CACHE_SIZE,
1920
ENGINE_CAPABILITY,
@@ -36,6 +37,7 @@
3637
WORKSPACE_SIZE,
3738
default_device,
3839
)
40+
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
3941

4042

4143
@dataclass
@@ -81,6 +83,7 @@ class CompilationSettings:
8183
load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
8284
engine_cache_dir (str): Directory to store the cached TRT engines
8385
engine_cache_size (int): Maximum hard-disk space to use for the engine cache
86+
engine_cache_class (Type[BaseEngineCache]): Engine cache class to use for saving and loading engines. Users can provide their own engine cache class by inheriting from BaseEngineCache
8487
"""
8588

8689
enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -116,3 +119,4 @@ class CompilationSettings:
116119
load_engine_cache: bool = LOAD_ENGINE_CACHE
117120
engine_cache_dir: str = ENGINE_CACHE_DIR
118121
engine_cache_size: int = ENGINE_CACHE_SIZE
122+
engine_cache_class: Type[BaseEngineCache] = ENGINE_CACHE_CLASS

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from torch_tensorrt._enums import dtype
1515
from torch_tensorrt._Input import Input
1616
from torch_tensorrt.dynamo import _defaults
17-
from torch_tensorrt.dynamo._engine_caching import EngineCache
1817
from torch_tensorrt.dynamo._settings import CompilationSettings
1918
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
2019
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -339,11 +338,13 @@ def run(
339338
self.compilation_settings.save_engine_cache
340339
or self.compilation_settings.load_engine_cache
341340
):
342-
engine_cache = EngineCache(
341+
engine_cache = self.compilation_settings.engine_cache_class(
343342
self.compilation_settings.engine_cache_size,
344343
self.compilation_settings.engine_cache_dir,
345344
)
346-
hash_val = EngineCache.get_hash(self.module)
345+
hash_val = self.compilation_settings.engine_cache_class.get_hash(
346+
self.module
347+
)
347348

348349
if self.compilation_settings.load_engine_cache:
349350
# query the cached TRT engine

0 commit comments

Comments
 (0)