Skip to content

Commit afb1516

Browse files
authored
fix: get_hash function for engine caching (#3293)
1 parent 0841f34 commit afb1516

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

py/torch_tensorrt/dynamo/_engine_cache.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
import pickletools
77
import shutil
88
from abc import ABC, abstractmethod
9-
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
9+
from typing import Any, Dict, List, Optional, Sequence, Tuple
1010

1111
import torch
12-
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
13-
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
12+
from torch._inductor.codecache import sha256_hash
1413
from torch_tensorrt._Input import Input
1514
from torch_tensorrt.dynamo._settings import (
1615
_SETTINGS_TO_BE_ENGINE_INVARIANT,
@@ -49,17 +48,38 @@ def get_hash(
4948
5049
Args:
5150
gm (torch.fx.GraphModule): GraphModule to hash
51+
input_specs (Sequence[Input]): input specs for the GraphModule
52+
settings (CompilationSettings): compilation settings for the GraphModule
5253
5354
Returns:
5455
str: hash value of the GraphModule
5556
"""
56-
# parameters are set to 0
57-
with unset_fake_temporarily():
58-
new_gm = copy.deepcopy(gm)
59-
for name, param in new_gm.named_parameters():
60-
param.data.zero_()
6157

62-
graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
58+
def canonicalize_graph(graph: torch.fx.Graph) -> str:
59+
"""Canonicalize the graph to a string for isomorphic graph comparison
60+
61+
Args:
62+
graph (torch.fx.Graph): graph to canonicalize
63+
64+
Returns:
65+
str: canonicalized graph string
66+
"""
67+
canonical_nodes = []
68+
input_counter = 0
69+
70+
for node in graph.nodes:
71+
if node.op == "placeholder":
72+
canonical_nodes.append(f"placeholder_input_{input_counter}")
73+
input_counter += 1
74+
else:
75+
canonical_nodes.append(f"{node.op}_{node.target}")
76+
77+
return " ".join(canonical_nodes)
78+
79+
graph_str = canonicalize_graph(gm.graph)
80+
_LOGGER.debug(f"graph_str:\n {graph_str}")
81+
82+
graph_hash = sha256_hash(graph_str.encode())
6383

6484
input_spec_strs = [str(i) for i in input_specs]
6585
with io.BytesIO() as stream:
@@ -75,7 +95,7 @@ def get_hash(
7595
engine_specs_data = pickletools.optimize(engine_specs_data)
7696
engine_specs_hash = sha256_hash(engine_specs_data)
7797

78-
hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash
98+
hash_val: str = graph_hash + input_specs_hash + engine_specs_hash
7999

80100
return hash_val
81101

tests/py/dynamo/models/test_engine_cache.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
231231
)
232232
end.record()
233233
torch.cuda.synchronize()
234-
torch._dynamo.reset()
235234
times.append(start.elapsed_time(end))
236235
results.append(trt_gm(*inputs))
237236

@@ -396,7 +395,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
396395
"reuse_cached_engines": reuse_cached_engines,
397396
"engine_cache_dir": engine_cache_dir,
398397
"engine_cache_size": 1 << 30, # 1GB
399-
"torch_executed_ops": {"torch.ops.aten.relu.default"},
400398
},
401399
)
402400
results.append(compiled_model(*inputs)) # trigger the compilation
@@ -441,7 +439,6 @@ def test_torch_compile_with_custom_engine_cache(self):
441439
start = torch.cuda.Event(enable_timing=True)
442440
end = torch.cuda.Event(enable_timing=True)
443441
for i in range(3):
444-
# remove timing cache and reset dynamo for engine caching messurement
445442
if i == 0:
446443
cache_built_engines = False
447444
reuse_cached_engines = False
@@ -462,7 +459,6 @@ def test_torch_compile_with_custom_engine_cache(self):
462459
"cache_built_engines": cache_built_engines,
463460
"reuse_cached_engines": reuse_cached_engines,
464461
"custom_engine_cache": custom_engine_cache,
465-
"torch_executed_ops": {"torch.ops.aten.relu.default"},
466462
},
467463
)
468464
results.append(compiled_model(*inputs)) # trigger the compilation

0 commit comments

Comments
 (0)