Commit cc4ee82

update dynamo path
1 parent d13a46b commit cc4ee82

File tree: 2 files changed (+48, -127 lines changed)
Lines changed: 35 additions & 116 deletions
@@ -1,10 +1,9 @@
-import time
+import os

 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
 import torchvision.models as models
-from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

 np.random.seed(0)
 torch.manual_seed(0)
@@ -19,127 +18,47 @@
 min_block_size = 0
 use_python_runtime = False
 torch_executed_ops = {}
+TIMING_CACHE_PATH = "/tmp/timing_cache.bin"


-def dynamo_path():
-    ############### warmup ###############
-    inputs = [torch.rand(size).to("cuda")]
-    t1 = time.time()
-    trt_gm = torch_trt.dynamo.compile(
-        exp_program,
-        tuple(inputs),
-        use_python_runtime=use_python_runtime,
-        enabled_precisions=enabled_precisions,
-        debug=debug,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        make_refitable=True,
-        ignore_engine_cache=True,
-    )  # Output is a torch.fx.GraphModule
-    t2 = time.time()
+def remove_timing_cache(path=TIMING_CACHE_PATH):
+    if os.path.exists(path):
+        os.remove(path)

-    ############### compile for the first time ###############
-    inputs = [torch.rand(size).to("cuda")]
-    t3 = time.time()
-    trt_gm1 = torch_trt.dynamo.compile(
-        exp_program,
-        tuple(inputs),
-        use_python_runtime=use_python_runtime,
-        enabled_precisions=enabled_precisions,
-        debug=debug,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        make_refitable=True,
-        ignore_engine_cache=False,
-    )  # Output is a torch.fx.GraphModule
-    t4 = time.time()
-    # Check the output
-    outputs = trt_gm1(*inputs)
-    print("----------> 1st output:", outputs)

-    ############### compile for the second time ###############
-    inputs = [torch.rand(size).to("cuda")]
-    t5 = time.time()
-    trt_gm2 = torch_trt.dynamo.compile(
-        exp_program,
-        tuple(inputs),
-        use_python_runtime=use_python_runtime,
-        enabled_precisions=enabled_precisions,
-        debug=debug,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        make_refitable=True,
-        ignore_engine_cache=False,
-    )  # Output is a torch.fx.GraphModule
-    t6 = time.time()
-    # Check the output
-    outputs = trt_gm2(*inputs)
-    print("----------> 2nd output:", outputs)
+def dynamo_path(iterations=3):
+    outputs = []
+    times = []
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+    for i in range(iterations):
+        inputs = [torch.rand(size).to("cuda")]
+        remove_timing_cache()
+        if i == 0:  # warmup
+            ignore_engine_cache = True
+        else:
+            ignore_engine_cache = False

-    print("----------> warmup compilation time:", t2 - t1, "seconds")
-    print("----------> 1st compilation time:", t4 - t3, "seconds")
-    print("----------> 2nd compilation time:", t6 - t5, "seconds")
+        start.record()
+        trt_gm = torch_trt.dynamo.compile(
+            exp_program,
+            tuple(inputs),
+            use_python_runtime=use_python_runtime,
+            enabled_precisions=enabled_precisions,
+            debug=debug,
+            min_block_size=min_block_size,
+            torch_executed_ops=torch_executed_ops,
+            make_refitable=True,
+            ignore_engine_cache=ignore_engine_cache,
+        )
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+        outputs.append(trt_gm(*inputs))

-
-def compile_path():
-    inputs = [torch.rand(size).to("cuda")]
-    model = models.resnet18(pretrained=True).eval().to("cuda")
-    t1 = time.time()
-    model = torch.compile(
-        model,
-        backend="tensorrt",
-        options={
-            "use_python_runtime": use_python_runtime,
-            "enabled_precisions": enabled_precisions,
-            "debug": debug,
-            "min_block_size": min_block_size,
-            "torch_executed_ops": torch_executed_ops,
-            "make_refitable": True,
-            "ignore_engine_cache": True,
-        },
-    )
-    t2 = time.time()
-    print("---------->", model(*inputs))
-
-    t3 = time.time()
-    model1 = torch.compile(
-        model,
-        backend="tensorrt",
-        options={
-            "use_python_runtime": use_python_runtime,
-            "enabled_precisions": enabled_precisions,
-            "debug": debug,
-            "min_block_size": min_block_size,
-            "torch_executed_ops": torch_executed_ops,
-            "make_refitable": True,
-            "ignore_engine_cache": False,
-        },
-    )
-    t4 = time.time()
-    print("----------> 1st output:", model1(*inputs))
-
-    t5 = time.time()
-    model2 = torch.compile(
-        model,
-        backend="tensorrt",
-        options={
-            "use_python_runtime": use_python_runtime,
-            "enabled_precisions": enabled_precisions,
-            "debug": debug,
-            "min_block_size": min_block_size,
-            "torch_executed_ops": torch_executed_ops,
-            "make_refitable": True,
-            "ignore_engine_cache": False,
-        },
-    )
-    t6 = time.time()
-    print("----------> 2nd output:", model2(*inputs))
-
-    print("----------> warmup compilation time:", t2 - t1, "seconds")
-    print("----------> 1st compilation time:", t4 - t3, "seconds")
-    print("----------> 2nd compilation time:", t6 - t5, "seconds")
+    print("-----dynamo_path-----> output:", outputs)
+    print("-----dynamo_path-----> compilation time:", times, "seconds")


 if __name__ == "__main__":
     dynamo_path()
-    compile_path()
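
Note on the new timing pattern: the commit swaps host-side time.time() deltas for CUDA events, which measure on the device stream and require a torch.cuda.synchronize() before the timer is read. One caveat worth knowing: Event.elapsed_time returns milliseconds, even though the print statement above labels the values "seconds". A minimal standalone sketch of the pattern (the matmul is an illustrative stand-in workload, not from the commit):

    import torch

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    # Stand-in workload; the commit times torch_trt.dynamo.compile here.
    a = torch.rand(1024, 1024, device="cuda")
    torch.matmul(a, a)
    end.record()

    torch.cuda.synchronize()  # both events must complete before reading
    print(f"elapsed: {start.elapsed_time(end):.2f} ms")  # milliseconds, not seconds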

py/torch_tensorrt/dynamo/_engine_caching.py

Lines changed: 13 additions & 11 deletions
@@ -98,7 +98,7 @@ def has_available_cache_size(self, serialized_engine: bytes) -> bool:
         Returns:
             bool: whether the cache has available size for the serialized engine
         """
-        return len(serialized_engine) <= self.available_engine_cache_size
+        return serialized_engine.nbytes <= self.available_engine_cache_size

     def clear_cache(self, size: int) -> None:

@@ -114,8 +114,8 @@ def save(
         input_names: List[str],
         output_names: List[str],
     ) -> None:
-        serialized_engine_size = len(serialized_engine)
-        if serialized_engine_size <= self.total_engine_cache_size:
+        serialized_engine_size = serialized_engine.nbytes
+        if serialized_engine_size > self.total_engine_cache_size:
             _LOGGER.warning(
                 f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}."
             )
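
Background for the len() to .nbytes switch in both hunks above: despite the bytes annotation, the serialized engine evidently behaves like a buffer object here, and len() on a buffer counts elements rather than bytes. A small stdlib illustration of the difference (assumed context, not from the commit):

    import array

    buf = memoryview(array.array("i", [1, 2, 3, 4]))  # four 4-byte integers
    print(len(buf))    # 4  -- element count
    print(buf.nbytes)  # 16 -- size in bytes, what a cache budget should compare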
@@ -124,13 +124,15 @@
         if not self.has_available_cache_size(serialized_engine):
             self.clear_cache(serialized_engine_size)

-        path = os.path.join(
-            self.engine_cache_dir, f"{hash}/engine_{input_names}_{output_names}.trt"
-        )
-        os.makedirs(os.path.dirname(path), exist_ok=True)
-        with open(path, "wb") as f:
-            f.write(serialized_engine)
-        _LOGGER.info(f"A TRT engine was cached to {path}")
+        if self.has_available_cache_size(serialized_engine):
+            path = os.path.join(
+                self.engine_cache_dir,
+                f"{hash}/engine--{input_names}--{output_names}.trt",
+            )
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            with open(path, "wb") as f:
+                f.write(serialized_engine)
+            _LOGGER.info(f"A TRT engine was cached to {path}")

     def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
         directory = os.path.join(self.engine_cache_dir, hash)
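
Taken together, the two save() hunks flip an inverted size guard (the warning previously fired when the engine did fit) and gate the disk write on a re-check, so eviction that frees less room than requested no longer leads to a write. A condensed sketch of the resulting control flow, using the method names from the diff and assuming the warning branch exits early (the surrounding unchanged lines are not shown in this hunk):

    def save_flow(self, serialized_engine):
        size = serialized_engine.nbytes
        if size > self.total_engine_cache_size:
            return  # can never fit: warn and bail out

        if not self.has_available_cache_size(serialized_engine):
            self.clear_cache(size)  # evict older engines to make room

        if self.has_available_cache_size(serialized_engine):
            ...  # build the hash-based path and write the engine to disk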
@@ -141,7 +143,7 @@ def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
         ), f"There are more than one engine {engine_list} under {directory}."
         path = os.path.join(directory, engine_list[0])
         input_names_str, output_names_str = (
-            engine_list[0].split(".")[0].split("_")[1:]
+            engine_list[0].split(".trt")[0].split("--")[1:]
         )
         input_names = ast.literal_eval(input_names_str)
         output_names = ast.literal_eval(output_names_str)
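
Why the separators changed: tensor names such as input_0 contain underscores themselves, so the old split("_") decoding mangled the embedded name lists; splitting on the ".trt" suffix and the "--" delimiters round-trips cleanly. A hypothetical round-trip check (the names are illustrative, not from the source):

    import ast

    input_names = ["input_0", "input_1"]  # underscores broke the old split("_")
    output_names = ["output_0"]

    filename = f"engine--{input_names}--{output_names}.trt"

    # Mirrors load(): drop the suffix, then split on the "--" delimiters.
    input_names_str, output_names_str = filename.split(".trt")[0].split("--")[1:]

    assert ast.literal_eval(input_names_str) == input_names
    assert ast.literal_eval(output_names_str) == output_names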

0 commit comments