Skip to content

Commit 6533d5c

Browse files
committed
force using slow refit, add unit tests
1 parent 034c2ba commit 6533d5c

File tree

3 files changed

+185
-18
lines changed

3 files changed

+185
-18
lines changed

examples/dynamo/engine_caching_example.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
np.random.seed(0)
1212
torch.manual_seed(0)
13-
size = (100, 3, 224, 224)
1413

1514
model = models.resnet18(pretrained=True).eval().to("cuda")
1615
enabled_precisions = {torch.float}
@@ -24,7 +23,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
2423
os.remove(path)
2524

2625

27-
def dynamo_path(iterations=3):
26+
def dynamo_compile(iterations=3):
2827
times = []
2928
start = torch.cuda.Event(enable_timing=True)
3029
end = torch.cuda.Event(enable_timing=True)
@@ -42,7 +41,7 @@ def dynamo_path(iterations=3):
4241
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
4342
for i in range(iterations):
4443
inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")]
45-
remove_timing_cache() # remove timing cache for engine caching messurement
44+
remove_timing_cache()  # remove timing cache just for engine caching measurement
4645
if i == 0:
4746
cache_built_engines = False
4847
reuse_cached_engines = False
@@ -63,11 +62,15 @@ def dynamo_path(iterations=3):
6362
reuse_cached_engines=reuse_cached_engines,
6463
engine_cache_size=1 << 30, # 1GB
6564
)
65+
# output = trt_gm(*inputs)
6666
end.record()
6767
torch.cuda.synchronize()
6868
times.append(start.elapsed_time(end))
6969

70-
print("-----dynamo_path-----> compilation time:\n", times, "milliseconds")
70+
print("----------------dynamo_compile----------------")
71+
print("disable engine caching, used:", times[0], "ms")
72+
print("enable engine caching to cache engines, used:", times[1], "ms")
73+
print("enable engine caching to reuse engines, used:", times[2], "ms")
7174

7275

7376
# Custom Engine Cache
@@ -84,11 +87,13 @@ def save(
8487
blob: bytes,
8588
prefix: str = "blob",
8689
):
90+
if not os.path.exists(self.engine_cache_dir):
91+
os.makedirs(self.engine_cache_dir, exist_ok=True)
92+
8793
path = os.path.join(
8894
self.engine_cache_dir,
8995
f"{prefix}_{hash}.bin",
9096
)
91-
os.makedirs(path, exist_ok=True)
9297
with open(path, "wb") as f:
9398
f.write(blob)
9499

@@ -101,7 +106,7 @@ def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
101106
return None
102107

103108

104-
def compile_path(iterations=3):
109+
def torch_compile(iterations=3):
105110
times = []
106111
engine_cache = MyEngineCache("/tmp/your_dir")
107112
start = torch.cuda.Event(enable_timing=True)
@@ -112,8 +117,8 @@ def compile_path(iterations=3):
112117
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
113118
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
114119
for i in range(iterations):
115-
inputs = [torch.rand(size).to("cuda")]
116-
# remove timing cache and reset dynamo for engine caching messurement
120+
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
121+
# remove timing cache and reset dynamo just for engine caching measurement
117122
remove_timing_cache()
118123
torch._dynamo.reset()
119124

@@ -129,7 +134,7 @@ def compile_path(iterations=3):
129134
model,
130135
backend="tensorrt",
131136
options={
132-
"use_python_runtime": use_python_runtime,
137+
"use_python_runtime": True,
133138
"enabled_precisions": enabled_precisions,
134139
"debug": debug,
135140
"min_block_size": min_block_size,
@@ -144,9 +149,12 @@ def compile_path(iterations=3):
144149
torch.cuda.synchronize()
145150
times.append(start.elapsed_time(end))
146151

147-
print("-----compile_path-----> compilation time:\n", times, "milliseconds")
152+
print("----------------torch_compile----------------")
153+
print("disable engine caching, used:", times[0], "ms")
154+
print("enable engine caching to cache engines, used:", times[1], "ms")
155+
print("enable engine caching to reuse engines, used:", times[2], "ms")
148156

149157

150158
if __name__ == "__main__":
151-
dynamo_path()
152-
# compile_path()
159+
dynamo_compile()
160+
torch_compile()

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -502,25 +502,31 @@ def run(
502502
"Found the cached engine that corresponds to this graph. It is directly loaded."
503503
)
504504

505+
runtime = trt.Runtime(TRT_LOGGER)
506+
engine = runtime.deserialize_cuda_engine(serialized_engine)
507+
505508
from torch_tensorrt.dynamo._refit import (
506509
_refit_single_trt_engine_with_gm,
507510
)
508511

509-
runtime = trt.Runtime(TRT_LOGGER)
510-
engine = runtime.deserialize_cuda_engine(serialized_engine)
511-
512+
# TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers.
513+
# We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future.
512514
_refit_single_trt_engine_with_gm(
513515
new_gm=self.module,
514516
old_engine=engine,
515517
input_list=self.input_specs,
516518
settings=self.compilation_settings,
517-
weight_name_map=weight_name_map,
519+
weight_name_map=None,
518520
)
519521

520-
serialized_engine = bytes(engine.serialize())
522+
serialized_engine = engine.serialize()
523+
524+
with io.BytesIO() as engine_bytes:
525+
engine_bytes.write(serialized_engine)
526+
engine_str = engine_bytes.getvalue()
521527

522528
return TRTInterpreterResult(
523-
serialized_engine,
529+
engine_str,
524530
self._input_names,
525531
self._output_names,
526532
self.weight_name_map,
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# type: ignore
2+
import os
3+
import shutil
4+
import unittest
5+
from typing import Optional
6+
7+
import torch
8+
import torch_tensorrt as torch_trt
9+
import torchvision.models as models
10+
from torch.testing._internal.common_utils import TestCase
11+
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR
12+
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
13+
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
14+
15+
assertions = unittest.TestCase()
16+
17+
18+
class MyEngineCache(BaseEngineCache):
    """Minimal filesystem-backed engine cache used by the tests.

    Each serialized engine blob is stored as ``<prefix>_<hash>.bin``
    inside ``engine_cache_dir``.
    """

    def __init__(
        self,
        engine_cache_dir: str,
    ) -> None:
        # The directory is created lazily on the first save().
        self.engine_cache_dir = engine_cache_dir

    def save(
        self,
        hash: str,
        blob: bytes,
        prefix: str = "blob",
    ):
        """Write ``blob`` to ``<engine_cache_dir>/<prefix>_<hash>.bin``.

        Args:
            hash: Content hash identifying the engine.
            blob: Serialized engine bytes.
            prefix: Filename prefix for the cached blob.
        """
        # exist_ok=True already tolerates a pre-existing directory,
        # so a separate os.path.exists() check is redundant.
        os.makedirs(self.engine_cache_dir, exist_ok=True)

        path = os.path.join(
            self.engine_cache_dir,
            f"{prefix}_{hash}.bin",
        )
        with open(path, "wb") as f:
            f.write(blob)

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        """Return the cached blob for ``hash``, or None if absent."""
        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
        if os.path.exists(path):
            with open(path, "rb") as f:
                blob = f.read()
            return blob
        return None
48+
49+
50+
class TestEngineCache(TestCase):
    """End-to-end tests for TRT engine caching via dynamo compile and torch.compile.

    NOTE(review): these tests require a CUDA device, torchvision weights,
    and a working torch_tensorrt install.
    """

    def test_dynamo_compile(self):
        """Compile three times and check outputs agree across cache states."""
        model = models.resnet18(pretrained=True).eval().to("cuda")
        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
        # Mark the dim0 of inputs as dynamic; "x" is resnet18.forward's arg name.
        batch = torch.export.Dim("batch", min=1, max=200)
        exp_program = torch.export.export(
            model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
        )
        engine_cache_dir = ENGINE_CACHE_DIR
        # Start from an empty cache so iteration semantics below hold.
        if os.path.exists(engine_cache_dir):
            shutil.rmtree(engine_cache_dir)
        # The 1st iteration measures compilation without engine caching.
        # The 2nd iteration compiles and saves the engine (slower than the 1st).
        # The 3rd iteration should be faster than the 1st because it loads the cached engine.
        inputs = [torch.rand((128, 3, 224, 224)).to("cuda")]
        results = []
        for i in range(3):
            if i == 0:
                cache_built_engines = False
                reuse_cached_engines = False
            else:
                cache_built_engines = True
                reuse_cached_engines = True

            trt_gm = torch_trt.dynamo.compile(
                exp_program,
                tuple(inputs),
                use_python_runtime=False,
                enabled_precisions={torch.float},
                debug=False,
                min_block_size=1,
                make_refitable=True,
                cache_built_engines=cache_built_engines,
                reuse_cached_engines=reuse_cached_engines,
                engine_cache_size=1 << 30,  # 1GB
            )
            results.append(trt_gm(*inputs))

        cos_sim = cosine_similarity(results[0], results[1])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"test_dynamo_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )

        cos_sim = cosine_similarity(results[1], results[2])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"test_dynamo_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )

    def test_torch_compile(self):
        """Same three-iteration check through torch.compile with a custom engine cache."""
        # Custom Engine Cache
        model = models.resnet18(pretrained=True).eval().to("cuda")

        engine_cache_dir = "/tmp/your_dir"
        if os.path.exists(engine_cache_dir):
            shutil.rmtree(engine_cache_dir)

        engine_cache = MyEngineCache(engine_cache_dir)
        # The 1st iteration measures compilation without engine caching.
        # The 2nd iteration compiles and saves the engine (slower than the 1st).
        # The 3rd iteration should be faster than the 1st because it loads the cached engine.
        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
        results = []
        for i in range(3):
            # Reset dynamo so every iteration recompiles and actually exercises
            # the engine cache (engine caching measurement); without this, later
            # iterations could reuse the previous dynamo compilation.
            torch._dynamo.reset()
            if i == 0:
                cache_built_engines = False
                reuse_cached_engines = False
            else:
                cache_built_engines = True
                reuse_cached_engines = True

            compiled_model = torch.compile(
                model,
                backend="tensorrt",
                options={
                    "use_python_runtime": True,
                    "enabled_precisions": {torch.float},
                    "debug": False,
                    "min_block_size": 1,
                    "make_refitable": True,
                    "cache_built_engines": cache_built_engines,
                    "reuse_cached_engines": reuse_cached_engines,
                    "custom_engine_cache": engine_cache,  # use custom engine cache
                },
            )
            results.append(compiled_model(*inputs))  # trigger the compilation

        cos_sim = cosine_similarity(results[0], results[1])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"test_torch_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )

        cos_sim = cosine_similarity(results[1], results[2])
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"test_torch_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
        )
)

0 commit comments

Comments
 (0)