|
| 1 | +import os |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | +import torch_tensorrt as torch_trt |
| 7 | +import torchvision.models as models |
| 8 | +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH |
| 9 | +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache |
| 10 | + |
| 11 | +np.random.seed(0) |
| 12 | +torch.manual_seed(0) |
| 13 | + |
| 14 | +model = models.resnet18(pretrained=True).eval().to("cuda") |
| 15 | +enabled_precisions = {torch.float} |
| 16 | +debug = False |
| 17 | +min_block_size = 1 |
| 18 | +use_python_runtime = False |
| 19 | + |
| 20 | + |
| 21 | +def remove_timing_cache(path=TIMING_CACHE_PATH): |
| 22 | + if os.path.exists(path): |
| 23 | + os.remove(path) |
| 24 | + |
| 25 | + |
| 26 | +def dynamo_compile(iterations=3): |
| 27 | + times = [] |
| 28 | + start = torch.cuda.Event(enable_timing=True) |
| 29 | + end = torch.cuda.Event(enable_timing=True) |
| 30 | + |
| 31 | + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) |
| 32 | + # Mark the dim0 of inputs as dynamic |
| 33 | + batch = torch.export.Dim("batch", min=1, max=200) |
| 34 | + exp_program = torch.export.export( |
| 35 | + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} |
| 36 | + ) |
| 37 | + |
| 38 | + # The 1st iteration is to measure the compilation time without engine caching |
| 39 | + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. |
| 40 | + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. |
| 41 | + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. |
| 42 | + for i in range(iterations): |
| 43 | + inputs = [torch.rand((100 + i, 3, 224, 224)).to("cuda")] |
| 44 | + remove_timing_cache() # remove timing cache just for engine caching messurement |
| 45 | + if i == 0: |
| 46 | + cache_built_engines = False |
| 47 | + reuse_cached_engines = False |
| 48 | + else: |
| 49 | + cache_built_engines = True |
| 50 | + reuse_cached_engines = True |
| 51 | + |
| 52 | + start.record() |
| 53 | + trt_gm = torch_trt.dynamo.compile( |
| 54 | + exp_program, |
| 55 | + tuple(inputs), |
| 56 | + use_python_runtime=use_python_runtime, |
| 57 | + enabled_precisions=enabled_precisions, |
| 58 | + debug=debug, |
| 59 | + min_block_size=min_block_size, |
| 60 | + make_refitable=True, |
| 61 | + cache_built_engines=cache_built_engines, |
| 62 | + reuse_cached_engines=reuse_cached_engines, |
| 63 | + engine_cache_size=1 << 30, # 1GB |
| 64 | + ) |
| 65 | + # output = trt_gm(*inputs) |
| 66 | + end.record() |
| 67 | + torch.cuda.synchronize() |
| 68 | + times.append(start.elapsed_time(end)) |
| 69 | + |
| 70 | + print("----------------dynamo_compile----------------") |
| 71 | + print("disable engine caching, used:", times[0], "ms") |
| 72 | + print("enable engine caching to cache engines, used:", times[1], "ms") |
| 73 | + print("enable engine caching to reuse engines, used:", times[2], "ms") |
| 74 | + |
| 75 | + |
| 76 | +# Custom Engine Cache |
| 77 | +class MyEngineCache(BaseEngineCache): |
| 78 | + def __init__( |
| 79 | + self, |
| 80 | + engine_cache_dir: str, |
| 81 | + ) -> None: |
| 82 | + self.engine_cache_dir = engine_cache_dir |
| 83 | + |
| 84 | + def save( |
| 85 | + self, |
| 86 | + hash: str, |
| 87 | + blob: bytes, |
| 88 | + prefix: str = "blob", |
| 89 | + ): |
| 90 | + if not os.path.exists(self.engine_cache_dir): |
| 91 | + os.makedirs(self.engine_cache_dir, exist_ok=True) |
| 92 | + |
| 93 | + path = os.path.join( |
| 94 | + self.engine_cache_dir, |
| 95 | + f"{prefix}_{hash}.bin", |
| 96 | + ) |
| 97 | + with open(path, "wb") as f: |
| 98 | + f.write(blob) |
| 99 | + |
| 100 | + def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]: |
| 101 | + path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin") |
| 102 | + if os.path.exists(path): |
| 103 | + with open(path, "rb") as f: |
| 104 | + blob = f.read() |
| 105 | + return blob |
| 106 | + return None |
| 107 | + |
| 108 | + |
| 109 | +def torch_compile(iterations=3): |
| 110 | + times = [] |
| 111 | + engine_cache = MyEngineCache("/tmp/your_dir") |
| 112 | + start = torch.cuda.Event(enable_timing=True) |
| 113 | + end = torch.cuda.Event(enable_timing=True) |
| 114 | + |
| 115 | + # The 1st iteration is to measure the compilation time without engine caching |
| 116 | + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. |
| 117 | + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. |
| 118 | + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. |
| 119 | + for i in range(iterations): |
| 120 | + inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] |
| 121 | + # remove timing cache and reset dynamo just for engine caching messurement |
| 122 | + remove_timing_cache() |
| 123 | + torch._dynamo.reset() |
| 124 | + |
| 125 | + if i == 0: |
| 126 | + cache_built_engines = False |
| 127 | + reuse_cached_engines = False |
| 128 | + else: |
| 129 | + cache_built_engines = True |
| 130 | + reuse_cached_engines = True |
| 131 | + |
| 132 | + start.record() |
| 133 | + compiled_model = torch.compile( |
| 134 | + model, |
| 135 | + backend="tensorrt", |
| 136 | + options={ |
| 137 | + "use_python_runtime": True, |
| 138 | + "enabled_precisions": enabled_precisions, |
| 139 | + "debug": debug, |
| 140 | + "min_block_size": min_block_size, |
| 141 | + "make_refitable": True, |
| 142 | + "cache_built_engines": cache_built_engines, |
| 143 | + "reuse_cached_engines": reuse_cached_engines, |
| 144 | + "custom_engine_cache": engine_cache, # use custom engine cache |
| 145 | + }, |
| 146 | + ) |
| 147 | + compiled_model(*inputs) # trigger the compilation |
| 148 | + end.record() |
| 149 | + torch.cuda.synchronize() |
| 150 | + times.append(start.elapsed_time(end)) |
| 151 | + |
| 152 | + print("----------------torch_compile----------------") |
| 153 | + print("disable engine caching, used:", times[0], "ms") |
| 154 | + print("enable engine caching to cache engines, used:", times[1], "ms") |
| 155 | + print("enable engine caching to reuse engines, used:", times[2], "ms") |
| 156 | + |
| 157 | + |
| 158 | +if __name__ == "__main__": |
| 159 | + dynamo_compile() |
| 160 | + torch_compile() |
0 commit comments