
Commit db2a523

refactor and add LRU to clear cache
1 parent 37d3311 commit db2a523

File tree

6 files changed: +171 -38 lines changed

examples/dynamo/engine_caching_example.py

Lines changed: 76 additions & 4 deletions
@@ -1,9 +1,17 @@
+import ast
+import logging
 import os
+from typing import List, Optional, Tuple

 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
 import torchvision.models as models
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
+from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+

 np.random.seed(0)
 torch.manual_seed(0)
@@ -14,7 +22,6 @@
 debug = False
 min_block_size = 1
 use_python_runtime = False
-TIMING_CACHE_PATH = "/tmp/timing_cache.bin"


 def remove_timing_cache(path=TIMING_CACHE_PATH):
@@ -26,10 +33,16 @@ def dynamo_path(iterations=3):
     times = []
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
-    inputs = [torch.rand(size).to("cuda")]
-    exp_program = torch.export.export(model, tuple(inputs))
+
+    example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+    # Mark the dim0 of inputs as dynamic
+    batch = torch.export.Dim("batch", min=1, max=200)
+    exp_program = torch.export.export(
+        model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+    )
+
     for i in range(iterations):
-        inputs = [torch.rand(size).to("cuda")]
+        inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
         remove_timing_cache()  # remove timing cache for engine caching messurement
         if i == 0:
             save_engine_cache = False
@@ -49,6 +62,7 @@ def dynamo_path(iterations=3):
             make_refitable=True,
             save_engine_cache=save_engine_cache,
             load_engine_cache=load_engine_cache,
+            engine_cache_size=1 << 30,  # 1GB
         )
         end.record()
         torch.cuda.synchronize()
@@ -57,8 +71,65 @@ def dynamo_path(iterations=3):
     print("-----dynamo_path-----> compilation time:", times, "milliseconds")


+# Custom Engine Cache
+class MyEngineCache(BaseEngineCache):
+
+    def __init__(
+        self,
+        engine_cache_size: int,
+        engine_cache_dir: str,
+    ) -> None:
+        self.total_engine_cache_size = engine_cache_size
+        self.available_engine_cache_size = engine_cache_size
+        self.engine_cache_dir = engine_cache_dir
+
+    def save(
+        self,
+        hash: str,
+        serialized_engine: bytes,
+        input_names: List[str],
+        output_names: List[str],
+    ) -> bool:
+        path = os.path.join(
+            self.engine_cache_dir,
+            f"{hash}/engine--{input_names}--{output_names}.trt",
+        )
+        try:
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            with open(path, "wb") as f:
+                f.write(serialized_engine)
+        except Exception as e:
+            _LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
+            return False
+
+        _LOGGER.info(f"A TRT engine was cached to {path}")
+        serialized_engine_size = int(serialized_engine.nbytes)
+        self.available_engine_cache_size -= serialized_engine_size
+        return True
+
+    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
+        directory = os.path.join(self.engine_cache_dir, hash)
+        if os.path.exists(directory):
+            engine_list = os.listdir(directory)
+            assert (
+                len(engine_list) == 1
+            ), f"There are more than one engine {engine_list} under {directory}."
+            path = os.path.join(directory, engine_list[0])
+            input_names_str, output_names_str = (
+                engine_list[0].split(".trt")[0].split("--")[1:]
+            )
+            input_names = ast.literal_eval(input_names_str)
+            output_names = ast.literal_eval(output_names_str)
+            with open(path, "rb") as f:
+                serialized_engine = f.read()
+            return serialized_engine, input_names, output_names
+        else:
+            return None, [], []
+
+
 def compile_path(iterations=3):
     times = []
+    engine_cache = MyEngineCache(200 * (1 << 20), "/tmp/your_dir")
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)

@@ -87,6 +158,7 @@ def compile_path(iterations=3):
                 "make_refitable": True,
                 "save_engine_cache": save_engine_cache,
                 "load_engine_cache": load_engine_cache,
+                "engine_cache_instance": engine_cache,  # use custom engine cache
             },
         )
         compiled_model(*inputs)  # trigger the compilation
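The example's MyEngineCache encodes the input and output names directly into the cached file name and recovers them with ast.literal_eval when loading. Below is a minimal standalone sketch of that file-name round trip, using a made-up hash and a dummy payload rather than a real TRT engine (not part of the commit):

import ast
import os

cache_dir = "/tmp/your_dir"           # illustrative cache location
hash_val = "abc123"                   # hypothetical engine hash
input_names, output_names = ["x"], ["output0"]

# save(): names are embedded in the file name under a per-hash directory
path = os.path.join(cache_dir, f"{hash_val}/engine--{input_names}--{output_names}.trt")
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "wb") as f:
    f.write(b"fake-serialized-engine")  # stand-in for serialized engine bytes

# load(): recover the names from the file name, then read the bytes back
file_name = os.listdir(os.path.dirname(path))[0]
in_str, out_str = file_name.split(".trt")[0].split("--")[1:]
assert ast.literal_eval(in_str) == input_names
assert ast.literal_eval(out_str) == output_names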

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 27 additions & 4 deletions
@@ -3,7 +3,7 @@
 import collections.abc
 import logging
 import warnings
-from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Type, Union
+from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union

 import torch
 from torch.export import ExportedProgram
@@ -84,7 +84,7 @@ def compile(
     load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
     engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
     engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
-    engine_cache_class: Type[BaseEngineCache] = EngineCache,
+    engine_cache_instance: Optional[BaseEngineCache] = None,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -149,7 +149,7 @@ def compile(
         load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
         engine_cache_dir (str): Directory to store the cached TRT engines
         engine_cache_size (int): Maximum hard-disk space to use for the engine cache
-        engine_cache_class (BaseEngineCache): Engine cache class to use for saving and loading engines. Users can provide their own engine cache class by inheriting from BaseEngineCache
+        engine_cache_instance (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -213,6 +213,11 @@ def compile(
     gm = post_lowering(gm, torch_inputs)
     logger.debug("Lowered Input graph: " + str(gm.graph))

+    if engine_cache_instance is None:
+        engine_cache_instance = EngineCacheInstanceCreator.get_creator(
+            engine_cache_size, engine_cache_dir
+        ).engine_cache_instance
+
     compilation_options = {
         "enabled_precisions": (
             enabled_precisions if enabled_precisions else _defaults.ENABLED_PRECISIONS
@@ -249,7 +254,7 @@ def compile(
         "load_engine_cache": load_engine_cache,
         "engine_cache_dir": engine_cache_dir,
         "engine_cache_size": engine_cache_size,
-        "engine_cache_class": engine_cache_class,
+        "engine_cache_instance": engine_cache_instance,
     }

     settings = CompilationSettings(**compilation_options)
@@ -666,3 +671,21 @@ def convert_module_to_trt_engine(
     engine_bytearray = engine_bytes.getvalue()

     return engine_bytearray
+
+
+class EngineCacheInstanceCreator:
+    engine_cache_creator = None
+
+    def __init__(self, engine_cache_size: int, engine_cache_dir: str) -> None:
+        self.engine_cache_instance = EngineCache(
+            engine_cache_size=engine_cache_size,
+            engine_cache_dir=engine_cache_dir,
+        )
+
+    @classmethod
+    def get_creator(
+        cls, engine_cache_size: int, engine_cache_dir: str
+    ) -> EngineCacheInstanceCreator:
+        if cls.engine_cache_creator is None:
+            cls.engine_cache_creator = cls(engine_cache_size, engine_cache_dir)
+        return cls.engine_cache_creator
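EngineCacheInstanceCreator is a lazily-initialized, process-wide singleton: the first compile() call that omits engine_cache_instance builds the default EngineCache, and later calls reuse that same instance, so their engine_cache_size/engine_cache_dir arguments no longer affect the default cache. A stripped-down sketch of the same pattern in isolation (the class and variable names here are illustrative, not from the commit):

class LazyCacheCreator:
    # Hypothetical stand-in for EngineCacheInstanceCreator
    _creator = None

    def __init__(self, size: int, cache_dir: str) -> None:
        # In the real code this field holds an EngineCache; here just the args.
        self.cache = {"size": size, "dir": cache_dir}

    @classmethod
    def get_creator(cls, size: int, cache_dir: str) -> "LazyCacheCreator":
        if cls._creator is None:  # constructed only on the first call
            cls._creator = cls(size, cache_dir)
        return cls._creator


first = LazyCacheCreator.get_creator(1 << 30, "/tmp/cache_a")
second = LazyCacheCreator.get_creator(1 << 20, "/tmp/cache_b")  # ignored: already created
assert first is second and second.cache["dir"] == "/tmp/cache_a"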

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 7 additions & 3 deletions
@@ -32,12 +32,16 @@
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
 SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
-TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
+TIMING_CACHE_PATH = os.path.join(
+    tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
+)
 SAVE_ENGINE_CACHE = True
 LOAD_ENGINE_CACHE = True
 ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
-ENGINE_CACHE_SIZE = 1 << 30
-ENGINE_CACHE_CLASS = EngineCache
+ENGINE_CACHE_SIZE = 1073741824
+ENGINE_CACHE_INSTANCE = EngineCache(
+    engine_cache_size=ENGINE_CACHE_SIZE, engine_cache_dir=ENGINE_CACHE_DIR
+)


 def default_device() -> Device:
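The default budget is unchanged in value, only in spelling: 1073741824 is the same 1 GiB previously written as a bit shift. A one-line check:

assert 1073741824 == 1 << 30 == 2**30  # 1 GiB default engine cache size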

py/torch_tensorrt/dynamo/_engine_caching.py

Lines changed: 54 additions & 16 deletions
@@ -2,8 +2,9 @@
 import copy
 import logging
 import os
+import shutil
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Tuple, cast

 import torch
 from torch._inductor.codecache import FxGraphCachePickler
@@ -75,15 +76,6 @@ def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
         """
         pass

-    @abstractmethod
-    def clear_cache(self, size: int) -> None:
-        """Clear the cache to make sure at least `size` bytes are available
-
-        Args:
-            size (int): the needed size
-        """
-        pass
-

 class EngineCache(BaseEngineCache):

@@ -95,6 +87,7 @@ def __init__(
         self.total_engine_cache_size = engine_cache_size
         self.available_engine_cache_size = engine_cache_size
         self.engine_cache_dir = engine_cache_dir
+        self.hash2size_map: Dict[str, int] = {}

     def has_available_cache_size(self, serialized_engine: bytes) -> bool:
         """Check if the cache has available space for saving the serialized engine
@@ -107,12 +100,53 @@ def has_available_cache_size(self, serialized_engine: bytes) -> bool:
         """
         return int(serialized_engine.nbytes) <= self.available_engine_cache_size

-    def clear_cache(self, size: int) -> None:
+    def clear_cache(self, needed_min_size: int) -> bool:
+        """Clear the cache to make sure at least `needed_min_size` bytes are available, if possible

-        def LRU() -> None:
-            pass
+        Args:
+            needed_min_size (int): the minimum needed size

-        pass
+        Returns:
+            bool: whether the cache is cleared successfully
+        """
+
+        def LRU() -> bool:
+            """Clear the Least Recently Used engine in the cache"""
+            # Get the list of engine directories
+            engines_hash_values = os.listdir(self.engine_cache_dir)
+            # Sort the engine directories by modification time (oldest first)
+            engines_hash_values.sort(
+                key=lambda x: os.path.getmtime(os.path.join(self.engine_cache_dir, x))
+            )
+            # Iterate over the engine directories and remove the oldest ones until enough space is available
+            for engine_hash in engines_hash_values:
+                if self.available_engine_cache_size >= needed_min_size:
+                    break
+                engine_path = os.path.join(self.engine_cache_dir, engine_hash)
+                try:
+                    # Remove the entire directory
+                    shutil.rmtree(engine_path)
+                    # Update the available cache size
+                    self.available_engine_cache_size += self.hash2size_map.pop(
+                        engine_hash, 0
+                    )
+                    _LOGGER.info(
+                        f"Removed the engine cache at {engine_path}, available cache size: {self.available_engine_cache_size} bytes."
+                    )
+                except Exception as e:
+                    _LOGGER.warning(
+                        f"Failed to clear the engine cache at {engine_path}: {e}"
+                    )
+                    return False
+            return True
+
+        if not os.path.exists(self.engine_cache_dir):
+            return False
+
+        _LOGGER.info(
+            f"Total cache size: {self.total_engine_cache_size} bytes; available cache size: {self.available_engine_cache_size} bytes. Clearing the cache to make sure at least {needed_min_size} bytes are available."
+        )
+        return LRU()

     def save(
         self,
@@ -128,9 +162,11 @@ def save(
             )
             return False

+        # Check if there is enough available cache size for the serialized engine
         if not self.has_available_cache_size(serialized_engine):
            self.clear_cache(serialized_engine_size)

+        # Save the serialized engine to the cache directory
        if self.has_available_cache_size(serialized_engine):
            path = os.path.join(
                self.engine_cache_dir,
@@ -140,12 +176,14 @@ def save(
                os.makedirs(os.path.dirname(path), exist_ok=True)
                with open(path, "wb") as f:
                    f.write(serialized_engine)
+                self.hash2size_map[hash] = serialized_engine_size
+                self.available_engine_cache_size -= serialized_engine_size
+                _LOGGER.info(f"A TRT engine was cached to {path}")
+
            except Exception as e:
                _LOGGER.warning(f"Failed to save the TRT engine to {path}: {e}")
                return False

-            _LOGGER.info(f"A TRT engine was cached to {path}")
-            self.available_engine_cache_size -= serialized_engine_size
            return True

        else:
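The eviction helper orders the per-hash engine directories by modification time and deletes the oldest until enough space is free, crediting the freed bytes back through hash2size_map. Since load() as written does not touch the directory, the ordering effectively reflects when each engine was written. A self-contained sketch of the same mtime-based eviction, using throwaway files and made-up sizes rather than real engines:

import os
import shutil
import tempfile
import time

cache_dir = tempfile.mkdtemp()   # illustrative stand-in for the engine cache dir
total_size = 200                 # total budget in bytes (illustrative)
available = total_size
sizes = {}                       # name -> size, like hash2size_map

# Populate three fake "engines", 60 bytes each, oldest first.
for name in ("old", "mid", "new"):
    d = os.path.join(cache_dir, name)
    os.makedirs(d)
    with open(os.path.join(d, "engine.trt"), "wb") as f:
        f.write(b"x" * 60)
    sizes[name] = 60
    available -= 60
    time.sleep(0.01)             # force distinct directory mtimes

def clear_cache(needed_min_size: int) -> bool:
    """Evict oldest-written entries until `needed_min_size` bytes are free."""
    global available
    entries = sorted(
        os.listdir(cache_dir),
        key=lambda e: os.path.getmtime(os.path.join(cache_dir, e)),
    )
    for entry in entries:
        if available >= needed_min_size:
            break
        shutil.rmtree(os.path.join(cache_dir, entry))
        available += sizes.pop(entry, 0)
    return available >= needed_min_size

clear_cache(60)
print(sorted(os.listdir(cache_dir)))  # ['mid', 'new'] -- only the oldest entry was evicted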

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Collection, Optional, Set, Type, Union
+from typing import Collection, Optional, Set, Union

 from torch.fx.node import Target
 from torch_tensorrt._Device import Device
@@ -14,8 +14,8 @@
     DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENABLED_PRECISIONS,
-    ENGINE_CACHE_CLASS,
     ENGINE_CACHE_DIR,
+    ENGINE_CACHE_INSTANCE,
     ENGINE_CACHE_SIZE,
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,
@@ -83,7 +83,7 @@ class CompilationSettings:
         load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
         engine_cache_dir (str): Directory to store the cached TRT engines
         engine_cache_size (int): Maximum hard-disk space to use for the engine cache
-        engine_cache_class (BaseEngineCache): Engine cache class to use for saving and loading engines. Users can provide their own engine cache class by inheriting from BaseEngineCache
+        engine_cache_instance (BaseEngineCache): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -119,4 +119,4 @@ class CompilationSettings:
     load_engine_cache: bool = LOAD_ENGINE_CACHE
     engine_cache_dir: str = ENGINE_CACHE_DIR
     engine_cache_size: int = ENGINE_CACHE_SIZE
-    engine_cache_class: Type[BaseEngineCache] = ENGINE_CACHE_CLASS
+    engine_cache_instance: BaseEngineCache = ENGINE_CACHE_INSTANCE

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 3 additions & 7 deletions
@@ -303,6 +303,7 @@ def _save_timing_cache(
         This is called after a TensorRT engine is built. Save the timing cache
         """
         timing_cache = builder_config.get_timing_cache()
+        os.makedirs(os.path.dirname(timing_cache_path), exist_ok=True)
         with open(timing_cache_path, "wb") as timing_cache_file:
             timing_cache_file.write(memoryview(timing_cache.serialize()))

@@ -338,13 +339,8 @@ def run(
             self.compilation_settings.save_engine_cache
             or self.compilation_settings.load_engine_cache
         ):
-            engine_cache = self.compilation_settings.engine_cache_class(
-                self.compilation_settings.engine_cache_size,
-                self.compilation_settings.engine_cache_dir,
-            )
-            hash_val = self.compilation_settings.engine_cache_class.get_hash(
-                self.module
-            )
+            engine_cache = self.compilation_settings.engine_cache_instance
+            hash_val = engine_cache.get_hash(self.module)

             if self.compilation_settings.load_engine_cache:
                 # query the cached TRT engine

0 commit comments
