6
6
import pickletools
7
7
import shutil
8
8
from abc import ABC , abstractmethod
9
- from typing import Any , Dict , List , Optional , Sequence , Tuple , cast
9
+ from typing import Any , Dict , List , Optional , Sequence , Tuple
10
10
11
11
import torch
12
- from torch ._inductor .codecache import FxGraphCachePickler , sha256_hash
13
- from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
12
+ from torch ._inductor .codecache import sha256_hash
14
13
from torch_tensorrt ._Input import Input
15
14
from torch_tensorrt .dynamo ._settings import (
16
15
_SETTINGS_TO_BE_ENGINE_INVARIANT ,
@@ -49,17 +48,38 @@ def get_hash(
49
48
50
49
Args:
51
50
gm (torch.fx.GraphModule): GraphModule to hash
51
+ input_specs (Sequence[Input]): input specs for the GraphModule
52
+ settings (CompilationSettings): compilation settings for the GraphModule
52
53
53
54
Returns:
54
55
str: hash value of the GraphModule
55
56
"""
56
- # parameters are set to 0
57
- with unset_fake_temporarily ():
58
- new_gm = copy .deepcopy (gm )
59
- for name , param in new_gm .named_parameters ():
60
- param .data .zero_ ()
61
57
62
- graph_hash_val = cast (str , FxGraphCachePickler .get_hash (new_gm ))
58
+ def canonicalize_graph (graph : torch .fx .Graph ) -> str :
59
+ """Canonicalize the graph to a string for isomorphic graph comparison
60
+
61
+ Args:
62
+ graph (torch.fx.Graph): graph to canonicalize
63
+
64
+ Returns:
65
+ str: canonicalized graph string
66
+ """
67
+ canonical_nodes = []
68
+ input_counter = 0
69
+
70
+ for node in graph .nodes :
71
+ if node .op == "placeholder" :
72
+ canonical_nodes .append (f"placeholder_input_{ input_counter } " )
73
+ input_counter += 1
74
+ else :
75
+ canonical_nodes .append (f"{ node .op } _{ node .target } " )
76
+
77
+ return " " .join (canonical_nodes )
78
+
79
+ graph_str = canonicalize_graph (gm .graph )
80
+ _LOGGER .debug (f"graph_str:\n { graph_str } " )
81
+
82
+ graph_hash = sha256_hash (graph_str .encode ())
63
83
64
84
input_spec_strs = [str (i ) for i in input_specs ]
65
85
with io .BytesIO () as stream :
@@ -75,7 +95,7 @@ def get_hash(
75
95
engine_specs_data = pickletools .optimize (engine_specs_data )
76
96
engine_specs_hash = sha256_hash (engine_specs_data )
77
97
78
- hash_val : str = graph_hash_val + input_specs_hash + engine_specs_hash
98
+ hash_val : str = graph_hash + input_specs_hash + engine_specs_hash
79
99
80
100
return hash_val
81
101
0 commit comments