Commit 53c42b0 (parent: 2518db5)

fix: distinguish engines based on compilation settings in addition to graph structure

Signed-off-by: Naren Dasan <[email protected]>

File tree: 4 files changed (+25, -16 lines)

py/torch_tensorrt/dynamo/_engine_cache.py

Lines changed: 5 additions & 8 deletions
@@ -75,7 +75,6 @@ def get_hash(
         engine_specs_data = pickletools.optimize(engine_specs_data)
         engine_specs_hash = sha256_hash(engine_specs_data)
 
-        # TODO: Super first idea I had hash combination solution @Evan please iterate on this
         hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash
 
         return hash_val
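
Taken together, `get_hash` now keys the cache on graph structure, input specs, and compilation settings, so two engines built from the same graph with different settings get distinct cache entries. A minimal, self-contained sketch of that combination scheme (the helper names and `pickletools` usage mirror the diff, but `toy_get_hash` is an illustration, not the library's exact code):

import hashlib
import pickle
import pickletools

def sha256_hash(data: bytes) -> str:
    # hex digest used as one component of the cache key
    return hashlib.sha256(data).hexdigest()

def toy_get_hash(graph_repr: bytes, input_specs: object, compilation_settings: object) -> str:
    # hash the graph structure
    graph_hash_val = sha256_hash(graph_repr)
    # hash the input specs (shape ranges, dtypes, ...)
    input_specs_data = pickletools.optimize(pickle.dumps(input_specs))
    input_specs_hash = sha256_hash(input_specs_data)
    # hash the compilation settings, so engines built with different
    # settings no longer collide on the same cache entry
    engine_specs_data = pickletools.optimize(pickle.dumps(compilation_settings))
    engine_specs_hash = sha256_hash(engine_specs_data)
    # concatenate the three digests, as in the diff above
    return graph_hash_val + input_specs_hash + engine_specs_hash
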
@@ -95,6 +94,8 @@ def pack(
             serialized_engine (bytes): serialized TRT engine
             input_names (List[str]): input names of TRT engine
             output_names (List[str]): output names of TRT engine
+            input_specs (Sequence[Input]): input specs of TRT engine
+            compilation_settings (CompilationSettings): compilation settings of TRT engine
             weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
 
         Returns:
@@ -121,7 +122,7 @@ def unpack(packed_obj: bytes) -> UnpackedCacheHit:
             packed_obj (bytes): packed blob
 
         Returns:
-            Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, CompilationSettings, weight name map
+            Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map
         """
         unpacked = pickle.loads(packed_obj)
         return (
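
With `pack` and `unpack` extended this way, the cached blob carries everything needed to validate a hit: the input specs and compilation settings travel alongside the serialized engine. A rough sketch of that round trip (the field names follow the docstring above; the dict layout is an assumption, not the library's actual format):

import pickle
from typing import Any, Dict, List, Optional, Sequence, Tuple

def toy_pack(
    serialized_engine: bytes,
    input_names: List[str],
    output_names: List[str],
    input_specs: Sequence[Any],
    compilation_settings: Any,
    weight_name_map: Optional[Dict[Any, Any]],
) -> bytes:
    # everything needed to validate a cache hit is stored in the blob,
    # including the settings the engine was originally built with
    return pickle.dumps(
        {
            "serialized_engine": serialized_engine,
            "input_names": input_names,
            "output_names": output_names,
            "input_specs": input_specs,
            "compilation_settings": compilation_settings,
            "weight_name_map": weight_name_map,
        }
    )

def toy_unpack(packed_obj: bytes) -> Tuple[Any, ...]:
    u = pickle.loads(packed_obj)
    # six-element tuple, matching the updated unpack docstring
    return (
        u["serialized_engine"],
        u["input_names"],
        u["output_names"],
        u["input_specs"],
        u["compilation_settings"],
        u["weight_name_map"],
    )
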
@@ -283,11 +284,7 @@ def LRU() -> None:
         else:
             LRU()
 
-    def save(
-        self,
-        hash: str,
-        blob: bytes,
-    ) -> None:
+    def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         blob_size = len(blob)
         if blob_size > self.total_engine_cache_size:
             _LOGGER.warning(
@@ -324,7 +321,7 @@ def save(
                 f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}."
             )
 
-    def load(self, hash: str) -> Optional[bytes]:
+    def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
         directory = os.path.join(self.engine_cache_dir, hash)
         if os.path.exists(directory):
             blob_path = os.path.join(directory, "blob.bin")
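
Widening `save` and `load` to accept `*args`/`**kwargs` keeps `BaseEngineCache` subclasses source-compatible when callers pass extra, implementation-specific arguments. A hypothetical in-memory cache showing the shape such a subclass takes (the class itself is invented for illustration; real caches would subclass `BaseEngineCache`):

from typing import Any, Dict, Optional

class InMemoryEngineCache:
    """Hypothetical cache keyed by the combined hash from get_hash."""

    def __init__(self) -> None:
        self.store: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
        # extra positional/keyword arguments are accepted (and ignored here)
        # so callers can pass implementation-specific options
        self.store[hash] = blob

    def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
        # returns None on a cache miss
        return self.store.get(hash)
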

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 2 additions & 2 deletions
@@ -557,7 +557,7 @@ def run(
             )
             assert (
                 setting_compatiblity
-            ), f"Attempted to refit a prebuilt engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
+            ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"
 
             for i, e in enumerate(
                 [
@@ -567,7 +567,7 @@ def run(
             ):
                 assert (
                     e
-                ), f"Found that cached engine was built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"
+                ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"
 
                 _LOGGER.info(
                     "Found the cached engine that corresponds to this graph. It is directly loaded."

py/torch_tensorrt/dynamo/utils.py

Lines changed: 1 addition & 1 deletion
@@ -499,7 +499,7 @@ def parse_dynamo_kwargs(
 
     # If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided,
     # then create a default disk engine cache
-    #
+
     engine_cache = None
     if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
         assert kwargs.get(
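
The comment retained here describes a fallback: build a default disk engine cache only when caching is requested and no custom cache was supplied. A self-contained sketch of that decision logic (the `DiskEngineCacheSketch` stand-in and the cache directory are hypothetical):

from typing import Any, Optional

class DiskEngineCacheSketch:
    """Hypothetical stand-in for the default disk engine cache."""

    def __init__(self, cache_dir: str) -> None:
        self.cache_dir = cache_dir

def resolve_engine_cache(kwargs: dict) -> Optional[Any]:
    # mirror the comment in parse_dynamo_kwargs: only build a cache when
    # caching or reuse is requested, and prefer a user-supplied one
    if not (kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines")):
        return None
    custom = kwargs.get("custom_engine_cache")
    return custom if custom is not None else DiskEngineCacheSketch("/tmp/engine_cache")
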

tests/py/dynamo/models/test_engine_cache.py

Lines changed: 17 additions & 5 deletions
@@ -9,7 +9,7 @@
 import torch_tensorrt as torch_trt
 import torchvision.models as models
 from torch.testing._internal.common_utils import TestCase
-from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR
+from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
@@ -160,9 +160,9 @@ def test_engine_settings_is_not_equal(self):
         )
         input_specs2 = (
             torch_trt.Input(
-                min_shape=(1, 3, 300, 300),
-                opt_shape=(100, 3, 300, 300),
-                max_shape=(200, 3, 300, 300),
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(100, 3, 224, 224),
+                max_shape=(200, 3, 224, 224),
             ),
         )
         settings2 = CompilationSettings(
@@ -192,6 +192,10 @@ def test_dynamo_compile_with_default_disk_engine_cache(self):
         if os.path.exists(engine_cache_dir):
             shutil.rmtree(engine_cache_dir)
 
+        def remove_timing_cache(path=TIMING_CACHE_PATH):
+            if os.path.exists(path):
+                os.remove(path)
+
         # The 1st iteration is to measure the compilation time without engine caching
         # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
         # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
@@ -202,6 +206,8 @@ def test_dynamo_compile_with_default_disk_engine_cache(self):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
+            remove_timing_cache()
+            torch._dynamo.reset()
             if i == 0:
                 cache_built_engines = False
                 reuse_cached_engines = False
@@ -351,6 +357,10 @@ def test_torch_compile_with_default_disk_engine_cache(self):
         if os.path.exists(engine_cache_dir):
             shutil.rmtree(engine_cache_dir)
 
+        def remove_timing_cache(path=TIMING_CACHE_PATH):
+            if os.path.exists(path):
+                os.remove(path)
+
         # The 1st iteration is to measure the compilation time without engine caching
         # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
         # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
@@ -361,7 +371,9 @@ def test_torch_compile_with_default_disk_engine_cache(self):
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)
         for i in range(3):
-            # remove timing cache and reset dynamo for engine caching messurement
+            # remove timing cache and reset dynamo for engine caching measurement
+            remove_timing_cache()
+            torch._dynamo.reset()
             if i == 0:
                 cache_built_engines = False
                 reuse_cached_engines = False
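
For reference, the caching behavior these tests time is driven by two compile flags. A usage sketch (the model and input shapes are placeholders, and a CUDA device is assumed):

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

model = models.resnet18().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# the first compile populates the default disk engine cache; recompiling
# the same graph with the same input specs and settings can then reuse
# the cached engine instead of rebuilding it
trt_model = torch_trt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    cache_built_engines=True,
    reuse_cached_engines=True,
)
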
