Skip to content

Commit 65e0171

Browse files
committed
add save_engine_cache and load_engine_cache args
1 parent cc4ee82 commit 65e0171

File tree

5 files changed

+76
-25
lines changed

5 files changed

+76
-25
lines changed

examples/dynamo/engine_caching_example.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,12 @@
88
np.random.seed(0)
99
torch.manual_seed(0)
1010
size = (100, 3, 224, 224)
11-
inputs = [torch.rand(size).to("cuda")]
1211

1312
model = models.resnet18(pretrained=True).eval().to("cuda")
14-
exp_program = torch.export.export(model, tuple(inputs))
1513
enabled_precisions = {torch.float}
1614
debug = False
17-
workspace_size = 20 << 30
18-
min_block_size = 0
15+
min_block_size = 1
1916
use_python_runtime = False
20-
torch_executed_ops = {}
2117
TIMING_CACHE_PATH = "/tmp/timing_cache.bin"
2218

2319

@@ -27,17 +23,20 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
2723

2824

2925
def dynamo_path(iterations=3):
30-
outputs = []
3126
times = []
3227
start = torch.cuda.Event(enable_timing=True)
3328
end = torch.cuda.Event(enable_timing=True)
29+
inputs = [torch.rand(size).to("cuda")]
30+
exp_program = torch.export.export(model, tuple(inputs))
3431
for i in range(iterations):
3532
inputs = [torch.rand(size).to("cuda")]
36-
remove_timing_cache()
37-
if i == 0: # warmup
38-
ignore_engine_cache = True
33+
remove_timing_cache()  # remove timing cache for engine caching measurement
34+
if i == 0:
35+
save_engine_cache = False
36+
load_engine_cache = False
3937
else:
40-
ignore_engine_cache = False
38+
save_engine_cache = True
39+
load_engine_cache = True
4140

4241
start.record()
4342
trt_gm = torch_trt.dynamo.compile(
@@ -47,18 +46,57 @@ def dynamo_path(iterations=3):
4746
enabled_precisions=enabled_precisions,
4847
debug=debug,
4948
min_block_size=min_block_size,
50-
torch_executed_ops=torch_executed_ops,
5149
make_refitable=True,
52-
ignore_engine_cache=ignore_engine_cache,
50+
save_engine_cache=save_engine_cache,
51+
load_engine_cache=load_engine_cache,
52+
)
53+
end.record()
54+
torch.cuda.synchronize()
55+
times.append(start.elapsed_time(end))
56+
57+
print("-----dynamo_path-----> compilation time:", times, "milliseconds")
58+
59+
60+
def compile_path(iterations=3):
61+
times = []
62+
start = torch.cuda.Event(enable_timing=True)
63+
end = torch.cuda.Event(enable_timing=True)
64+
65+
for i in range(iterations):
66+
inputs = [torch.rand(size).to("cuda")]
67+
# remove timing cache and reset dynamo for engine caching measurement
68+
remove_timing_cache()
69+
torch._dynamo.reset()
70+
71+
if i == 0:
72+
save_engine_cache = False
73+
load_engine_cache = False
74+
else:
75+
save_engine_cache = True
76+
load_engine_cache = True
77+
78+
start.record()
79+
compiled_model = torch.compile(
80+
model,
81+
backend="tensorrt",
82+
options={
83+
"use_python_runtime": use_python_runtime,
84+
"enabled_precisions": enabled_precisions,
85+
"debug": debug,
86+
"min_block_size": min_block_size,
87+
"make_refitable": True,
88+
"save_engine_cache": save_engine_cache,
89+
"load_engine_cache": load_engine_cache,
90+
},
5391
)
92+
compiled_model(*inputs) # trigger the compilation
5493
end.record()
5594
torch.cuda.synchronize()
5695
times.append(start.elapsed_time(end))
57-
outputs.append(trt_gm(*inputs))
5896

59-
print("-----dynamo_path-----> output:", outputs)
60-
print("-----dynamo_path-----> compilation time:", times, "seconds")
97+
print("-----compile_path-----> compilation time:", times, "milliseconds")
6198

6299

63100
if __name__ == "__main__":
64101
dynamo_path()
102+
compile_path()

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def compile(
7979
dryrun: bool = _defaults.DRYRUN,
8080
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
8181
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
82-
ignore_engine_cache: bool = _defaults.IGNORE_ENGINE_CACHE,
82+
save_engine_cache: bool = _defaults.SAVE_ENGINE_CACHE,
83+
load_engine_cache: bool = _defaults.LOAD_ENGINE_CACHE,
8384
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
8485
engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
8586
**kwargs: Any,
@@ -142,7 +143,8 @@ def compile(
142143
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
143144
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
144145
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
145-
ignore_engine_cache (bool): Whether to ignore the cached TRT engines and recompile the module
146+
save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk
147+
load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
146148
engine_cache_dir (str): Directory to store the cached TRT engines
147149
engine_cache_size (int): Maximum hard-disk space to use for the engine cache
148150
**kwargs: Any,
@@ -240,7 +242,8 @@ def compile(
240242
"dryrun": dryrun,
241243
"hardware_compatible": hardware_compatible,
242244
"timing_cache_path": timing_cache_path,
243-
"ignore_engine_cache": ignore_engine_cache,
245+
"save_engine_cache": save_engine_cache,
246+
"load_engine_cache": load_engine_cache,
244247
"engine_cache_dir": engine_cache_dir,
245248
"engine_cache_size": engine_cache_size,
246249
}

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
HARDWARE_COMPATIBLE = False
3333
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
3434
TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
35-
IGNORE_ENGINE_CACHE = False
35+
SAVE_ENGINE_CACHE = True
36+
LOAD_ENGINE_CACHE = True
3637
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
3738
ENGINE_CACHE_SIZE = 1 << 30
3839

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
ENGINE_CACHE_SIZE,
1919
ENGINE_CAPABILITY,
2020
HARDWARE_COMPATIBLE,
21-
IGNORE_ENGINE_CACHE,
21+
LOAD_ENGINE_CACHE,
2222
MAKE_REFITABLE,
2323
MAX_AUX_STREAMS,
2424
MIN_BLOCK_SIZE,
2525
NUM_AVG_TIMING_ITERS,
2626
OPTIMIZATION_LEVEL,
2727
PASS_THROUGH_BUILD_FAILURES,
2828
REQUIRE_FULL_COMPILATION,
29+
SAVE_ENGINE_CACHE,
2930
SPARSE_WEIGHTS,
3031
TIMING_CACHE_PATH,
3132
TRUNCATE_DOUBLE,
@@ -76,7 +77,8 @@ class CompilationSettings:
7677
output to a file if a string path is specified
7778
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
7879
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
79-
ignore_engine_cache (bool): Whether to ignore the cached TRT engines and recompile the module
80+
save_engine_cache (bool): Whether to save the compiled TRT engines to hard disk
81+
load_engine_cache (bool): Whether to load the compiled TRT engines from hard disk
8082
engine_cache_dir (str): Directory to store the cached TRT engines
8183
engine_cache_size (int): Maximum hard-disk space to use for the engine cache
8284
"""
@@ -110,6 +112,7 @@ class CompilationSettings:
110112
dryrun: Union[bool, str] = DRYRUN
111113
hardware_compatible: bool = HARDWARE_COMPATIBLE
112114
timing_cache_path: str = TIMING_CACHE_PATH
113-
ignore_engine_cache: bool = IGNORE_ENGINE_CACHE
115+
save_engine_cache: bool = SAVE_ENGINE_CACHE
116+
load_engine_cache: bool = LOAD_ENGINE_CACHE
114117
engine_cache_dir: str = ENGINE_CACHE_DIR
115118
engine_cache_size: int = ENGINE_CACHE_SIZE

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,16 +331,22 @@ def run(
331331
Args:
332332
strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
333333
algorithm_selector: set up algorithm selection for certain layer
334+
tactic_sources: set up tactic sources for certain layer
334335
Return:
335336
TRTInterpreterResult
336337
"""
337-
if not self.compilation_settings.ignore_engine_cache:
338-
# query the cached TRT engine
338+
if (
339+
self.compilation_settings.save_engine_cache
340+
or self.compilation_settings.load_engine_cache
341+
):
339342
engine_cache = EngineCache(
340343
self.compilation_settings.engine_cache_size,
341344
self.compilation_settings.engine_cache_dir,
342345
)
343346
hash_val = EngineCache.get_hash(self.module)
347+
348+
if self.compilation_settings.load_engine_cache:
349+
# query the cached TRT engine
344350
serialized_engine, input_names, output_names = engine_cache.load(hash_val)
345351
if serialized_engine is not None:
346352
self._input_names = input_names
@@ -390,7 +396,7 @@ def run(
390396
self._save_timing_cache(
391397
builder_config, self.compilation_settings.timing_cache_path
392398
)
393-
if not self.compilation_settings.ignore_engine_cache:
399+
if self.compilation_settings.save_engine_cache:
394400
engine_cache.save(
395401
hash_val, serialized_engine, self._input_names, self._output_names
396402
)

0 commit comments

Comments
 (0)