Skip to content

Commit e2ca04c

Browse files
committed
fix: distingush engines based on compilation settings in addition to graph structure
Signed-off-by: Naren Dasan <[email protected]>
1 parent 8154408 commit e2ca04c

18 files changed

+465
-120
lines changed

examples/dynamo/engine_caching_bert_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def compile_bert(iterations=3):
5252
"truncate_double": True,
5353
"debug": False,
5454
"min_block_size": 1,
55-
"make_refitable": True,
55+
"make_refittable": True,
5656
"cache_built_engines": cache_built_engines,
5757
"reuse_cached_engines": reuse_cached_engines,
5858
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",

examples/dynamo/engine_caching_example.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
6363
# in a subsequent compilation, either as part of this session or a new session, the cache will
6464
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
6565
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
66-
# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
66+
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
6767

6868

6969
def torch_compile(iterations=3):
@@ -97,7 +97,7 @@ def torch_compile(iterations=3):
9797
"enabled_precisions": enabled_precisions,
9898
"debug": debug,
9999
"min_block_size": min_block_size,
100-
"make_refitable": True,
100+
"make_refittable": True,
101101
"cache_built_engines": cache_built_engines,
102102
"reuse_cached_engines": reuse_cached_engines,
103103
},
@@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
157157
enabled_precisions=enabled_precisions,
158158
debug=debug,
159159
min_block_size=min_block_size,
160-
make_refitable=True,
160+
make_refittable=True,
161161
cache_built_engines=cache_built_engines,
162162
reuse_cached_engines=reuse_cached_engines,
163163
engine_cache_size=1 << 30, # 1GB
@@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
268268
"enabled_precisions": enabled_precisions,
269269
"debug": debug,
270270
"min_block_size": min_block_size,
271-
"make_refitable": True,
271+
"make_refittable": True,
272272
"cache_built_engines": cache_built_engines,
273273
"reuse_cached_engines": reuse_cached_engines,
274274
"custom_engine_cache": engine_cache,

examples/dynamo/mutable_torchtrt_module_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
settings = {
3232
"use_python": False,
3333
"enabled_precisions": {torch.float32},
34-
"make_refitable": True,
34+
"make_refittable": True,
3535
}
3636

3737
model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -80,7 +80,7 @@
8080
"use_python_runtime": True,
8181
"enabled_precisions": {torch.float16},
8282
"debug": True,
83-
"make_refitable": True,
83+
"make_refittable": True,
8484
}
8585

8686
model_id = "runwayml/stable-diffusion-v1-5"

examples/dynamo/refit_engine_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@
4343

4444

4545
# %%
46-
# Make a Refitable Compilation Program
46+
# Make a refittable Compilation Program
4747
# ---------------------------------------
4848
#
4949
# The inital step is to compile a module and save it as with a normal. Note that there is an
50-
# additional parameter `make_refitable` that is set to `True`. This parameter is used to
50+
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
5151
# indicate that the engine being built should support weight refitting later. Engines built without
5252
# these setttings will not be able to be refit.
5353
#
@@ -69,7 +69,7 @@
6969
debug=debug,
7070
min_block_size=min_block_size,
7171
torch_executed_ops=torch_executed_ops,
72-
make_refitable=True,
72+
make_refittable=True,
7373
) # Output is a torch.fx.GraphModule
7474

7575
# Save the graph module as an exported program

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def compile(
6060
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
6161
] = _defaults.ENABLED_PRECISIONS,
6262
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
63-
make_refitable: bool = _defaults.MAKE_REFITABLE,
63+
make_refittable: bool = _defaults.MAKE_REFITTABLE,
6464
debug: bool = _defaults.DEBUG,
6565
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
6666
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -180,14 +180,14 @@ def compile(
180180

181181
if "refit" in kwargs.keys():
182182
warnings.warn(
183-
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
183+
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
184184
DeprecationWarning,
185185
stacklevel=2,
186186
)
187-
if make_refitable:
188-
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
187+
if make_refittable:
188+
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
189189
else:
190-
make_refitable = kwargs["refit"]
190+
make_refittable = kwargs["refit"]
191191

192192
engine_capability = EngineCapability._from(engine_capability)
193193

@@ -238,8 +238,8 @@ def compile(
238238
engine_cache = None
239239
if cache_built_engines or reuse_cached_engines:
240240
assert (
241-
make_refitable
242-
), "Engine caching requires make_refitable to be set to True"
241+
make_refittable
242+
), "Engine caching requires make_refittable to be set to True"
243243
engine_cache = (
244244
custom_engine_cache
245245
if custom_engine_cache is not None
@@ -270,7 +270,7 @@ def compile(
270270
"require_full_compilation": require_full_compilation,
271271
"disable_tf32": disable_tf32,
272272
"sparse_weights": sparse_weights,
273-
"make_refitable": make_refitable,
273+
"make_refittable": make_refittable,
274274
"engine_capability": engine_capability,
275275
"dla_sram_size": dla_sram_size,
276276
"dla_local_dram_size": dla_local_dram_size,
@@ -513,7 +513,7 @@ def convert_exported_program_to_serialized_trt_engine(
513513
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
514514
disable_tf32: bool = _defaults.DISABLE_TF32,
515515
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
516-
make_refitable: bool = _defaults.MAKE_REFITABLE,
516+
make_refittable: bool = _defaults.MAKE_refittable,
517517
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
518518
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
519519
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -600,7 +600,7 @@ def convert_exported_program_to_serialized_trt_engine(
600600
)
601601
if "refit" in kwargs.keys():
602602
warnings.warn(
603-
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
603+
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
604604
DeprecationWarning,
605605
stacklevel=2,
606606
)
@@ -646,7 +646,7 @@ def convert_exported_program_to_serialized_trt_engine(
646646
"require_full_compilation": require_full_compilation,
647647
"disable_tf32": disable_tf32,
648648
"sparse_weights": sparse_weights,
649-
"make_refitable": make_refitable,
649+
"make_refittable": make_refittable,
650650
"engine_capability": engine_capability,
651651
"num_avg_timing_iters": num_avg_timing_iters,
652652
"dla_sram_size": dla_sram_size,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
USE_PYTHON_RUNTIME = False
2727
USE_FAST_PARTITIONER = True
2828
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
29-
MAKE_REFITABLE = False
29+
MAKE_REFITTABLE = False
3030
REQUIRE_FULL_COMPILATION = False
3131
DRYRUN = False
3232
HARDWARE_COMPATIBLE = False

py/torch_tensorrt/dynamo/_engine_cache.py

Lines changed: 92 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,34 @@
11
import copy
2+
import io
23
import logging
34
import os
45
import pickle
6+
import pickletools
57
import shutil
68
from abc import ABC, abstractmethod
79
from typing import Any, Dict, List, Optional, Tuple, cast
810

911
import torch
10-
from torch._inductor.codecache import FxGraphCachePickler
12+
from sympy.polys.matrices.dense import Sequence
13+
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
1114
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
15+
from torch_tensorrt._Input import Input
16+
from torch_tensorrt.dynamo._settings import (
17+
_SETTINGS_TO_BE_ENGINE_INVARIANT,
18+
CompilationSettings,
19+
)
1220

1321
_LOGGER: logging.Logger = logging.getLogger(__name__)
1422

23+
UnpackedCacheHit = Tuple[
24+
bytes,
25+
List[str],
26+
List[str],
27+
Sequence[Input],
28+
CompilationSettings,
29+
Optional[Dict[str, Any]],
30+
]
31+
1532

1633
class BaseEngineCache(ABC):
1734

@@ -24,7 +41,11 @@ def __init__(
2441
pass
2542

2643
@staticmethod
27-
def get_hash(gm: torch.fx.GraphModule) -> str:
44+
def get_hash(
45+
gm: torch.fx.GraphModule,
46+
input_specs: Sequence[Input],
47+
settings: CompilationSettings,
48+
) -> str:
2849
"""Get the hash value of the GraphModule
2950
3051
Args:
@@ -39,7 +60,24 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
3960
for name, param in new_gm.named_parameters():
4061
param.data.zero_()
4162

42-
hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
63+
graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
64+
65+
input_spec_strs = [str(i) for i in input_specs]
66+
with io.BytesIO() as stream:
67+
input_specs_data = pickle.dumps(input_spec_strs)
68+
input_specs_data = pickletools.optimize(input_specs_data)
69+
input_specs_hash = sha256_hash(input_specs_data)
70+
71+
invariant_engine_specs = [
72+
str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
73+
]
74+
with io.BytesIO() as stream:
75+
engine_specs_data = pickle.dumps(invariant_engine_specs)
76+
engine_specs_data = pickletools.optimize(engine_specs_data)
77+
engine_specs_hash = sha256_hash(engine_specs_data)
78+
79+
# TODO: Super first idea I had hash combination solution @Evan please iterate on this
80+
hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash
4381

4482
return hash_val
4583

@@ -48,6 +86,8 @@ def pack(
4886
serialized_engine: bytes,
4987
input_names: List[str],
5088
output_names: List[str],
89+
input_specs: Tuple[Input],
90+
compilation_settings: CompilationSettings,
5191
weight_name_map: Optional[Dict[Any, Any]],
5292
) -> bytes:
5393
"""Pack serialized engine, input names, output names, and weight map into a single blob
@@ -61,35 +101,80 @@ def pack(
61101
Returns:
62102
bytes: packed blob
63103
"""
104+
105+
settings = copy.deepcopy(compilation_settings)
106+
settings.torch_executed_ops = {
107+
f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops
108+
}
109+
64110
return pickle.dumps(
65111
{
66112
"serialized_engine": bytes(serialized_engine),
67113
"input_names": input_names,
68114
"output_names": output_names,
115+
"input_specs": input_specs,
116+
"compilation_settings": settings,
69117
"weight_name_map": weight_name_map,
70118
}
71119
)
72120

73121
@staticmethod
74-
def unpack(
75-
packed_obj: bytes,
76-
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
122+
def unpack(packed_obj: bytes) -> UnpackedCacheHit:
77123
"""Unpack packed blob into serialized engine, input names, output names, and weight map
78124
79125
Args:
80126
packed_obj (bytes): packed blob
81127
82128
Returns:
83-
Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
129+
Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, CompilationSettings, weight name map
84130
"""
85131
unpacked = pickle.loads(packed_obj)
86132
return (
87133
unpacked["serialized_engine"],
88134
unpacked["input_names"],
89135
unpacked["output_names"],
136+
unpacked["input_specs"],
137+
unpacked["compilation_settings"],
90138
unpacked["weight_name_map"],
91139
)
92140

141+
def insert(
142+
self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
143+
) -> None:
144+
"""
145+
Insert a cache entry into the engine cache.
146+
147+
Args:
148+
hash (str): The hash value of the GraphModule.
149+
entry (Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]): The cache entry to be inserted.
150+
*args: Variable length argument list passed to ``save``.
151+
**kwargs: Arbitrary keyword arguments passed to ``save``.
152+
153+
Returns:
154+
None
155+
"""
156+
packed_cache_info = BaseEngineCache.pack(*entry)
157+
return self.save(hash, packed_cache_info, *args, **kwargs)
158+
159+
def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
160+
"""
161+
Check if a cache entry exists for the given hash.
162+
163+
Args:
164+
hash (str): The hash value of the GraphModule.
165+
*args: Variable length argument list passed to ``load``.
166+
**kwargs: Arbitrary keyword arguments passed to ``load``.
167+
168+
Returns:
169+
Optional[Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
170+
"""
171+
packed_cache_info = self.load(hash, *args, **kwargs)
172+
173+
if packed_cache_info:
174+
return BaseEngineCache.unpack(packed_cache_info)
175+
else:
176+
return None
177+
93178
@abstractmethod
94179
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
95180
"""Store blob in cache

py/torch_tensorrt/dynamo/_refit.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def refit_module_weights(
239239
)
240240

241241
# Get the settings and check the setting to be uniform
242-
settings: CompilationSettings = None
242+
settings: Optional[CompilationSettings] = None
243243
if inline_module:
244244

245245
# Obtain the settings
@@ -254,7 +254,7 @@ def refit_module_weights(
254254
]
255255
assert (
256256
encoded_metadata != ""
257-
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
257+
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refittable=True"
258258
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
259259
# Handle torch modules
260260
compiled_submodules_map = dict(compiled_submodules)
@@ -269,8 +269,10 @@ def refit_module_weights(
269269
continue
270270
settings = submodule.settings
271271

272+
assert settings is not None
273+
272274
assert (
273-
settings.make_refitable
275+
settings.make_refittable
274276
), "Refitting is not enabled. Please recompile the engine with refit=True."
275277

276278
if settings.debug:
@@ -396,7 +398,7 @@ def refit_module_weights(
396398
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
397399
engine = compiled_submodule.engine
398400
elif isinstance(compiled_submodule, TorchTensorRTModule):
399-
engine_info = compiled_submodule.engine.__getstate__()[0]
401+
engine_info = compiled_submodule.engine.__getstate__()[0] # type: ignore[index]
400402
engine = get_engine_from_encoded_engine(
401403
engine_info[ENGINE_IDX], runtime
402404
)

0 commit comments

Comments
 (0)