|
| 1 | +# type: ignore |
| 2 | +import os |
| 3 | +import shutil |
| 4 | +import unittest |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +import torch |
| 8 | +import torch_tensorrt as torch_trt |
| 9 | +import torchvision.models as models |
| 10 | +from torch.testing._internal.common_utils import TestCase |
| 11 | +from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR |
| 12 | +from torch_tensorrt.dynamo._engine_caching import BaseEngineCache |
| 13 | +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity |
| 14 | + |
| 15 | +assertions = unittest.TestCase() |
| 16 | + |
| 17 | + |
| 18 | +class MyEngineCache(BaseEngineCache): |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + engine_cache_dir: str, |
| 22 | + ) -> None: |
| 23 | + self.engine_cache_dir = engine_cache_dir |
| 24 | + |
| 25 | + def save( |
| 26 | + self, |
| 27 | + hash: str, |
| 28 | + blob: bytes, |
| 29 | + prefix: str = "blob", |
| 30 | + ): |
| 31 | + if not os.path.exists(self.engine_cache_dir): |
| 32 | + os.makedirs(self.engine_cache_dir, exist_ok=True) |
| 33 | + |
| 34 | + path = os.path.join( |
| 35 | + self.engine_cache_dir, |
| 36 | + f"{prefix}_{hash}.bin", |
| 37 | + ) |
| 38 | + with open(path, "wb") as f: |
| 39 | + f.write(blob) |
| 40 | + |
| 41 | + def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]: |
| 42 | + path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin") |
| 43 | + if os.path.exists(path): |
| 44 | + with open(path, "rb") as f: |
| 45 | + blob = f.read() |
| 46 | + return blob |
| 47 | + return None |
| 48 | + |
| 49 | + |
| 50 | +class TestEngineCache(TestCase): |
| 51 | + |
| 52 | + def test_dynamo_compile(self): |
| 53 | + model = models.resnet18(pretrained=True).eval().to("cuda") |
| 54 | + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) |
| 55 | + # Mark the dim0 of inputs as dynamic |
| 56 | + batch = torch.export.Dim("batch", min=1, max=200) |
| 57 | + exp_program = torch.export.export( |
| 58 | + model, args=example_inputs, dynamic_shapes={"x": {0: batch}} |
| 59 | + ) |
| 60 | + engine_cache_dir = ENGINE_CACHE_DIR |
| 61 | + if os.path.exists(engine_cache_dir): |
| 62 | + shutil.rmtree(engine_cache_dir) |
| 63 | + # The 1st iteration is to measure the compilation time without engine caching |
| 64 | + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. |
| 65 | + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. |
| 66 | + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. |
| 67 | + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] |
| 68 | + results = [] |
| 69 | + for i in range(3): |
| 70 | + if i == 0: |
| 71 | + cache_built_engines = False |
| 72 | + reuse_cached_engines = False |
| 73 | + else: |
| 74 | + cache_built_engines = True |
| 75 | + reuse_cached_engines = True |
| 76 | + |
| 77 | + trt_gm = torch_trt.dynamo.compile( |
| 78 | + exp_program, |
| 79 | + tuple(inputs), |
| 80 | + use_python_runtime=False, |
| 81 | + enabled_precisions={torch.float}, |
| 82 | + debug=False, |
| 83 | + min_block_size=1, |
| 84 | + make_refitable=True, |
| 85 | + cache_built_engines=cache_built_engines, |
| 86 | + reuse_cached_engines=reuse_cached_engines, |
| 87 | + engine_cache_size=1 << 30, # 1GB |
| 88 | + ) |
| 89 | + results.append(trt_gm(*inputs)) |
| 90 | + |
| 91 | + cos_sim = cosine_similarity(results[0], results[1]) |
| 92 | + assertions.assertTrue( |
| 93 | + cos_sim > COSINE_THRESHOLD, |
| 94 | + msg=f"test_dynamo_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 95 | + ) |
| 96 | + |
| 97 | + cos_sim = cosine_similarity(results[1], results[2]) |
| 98 | + assertions.assertTrue( |
| 99 | + cos_sim > COSINE_THRESHOLD, |
| 100 | + msg=f"test_dynamo_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 101 | + ) |
| 102 | + |
| 103 | + def test_torch_compile(self): |
| 104 | + # Custom Engine Cache |
| 105 | + model = models.resnet18(pretrained=True).eval().to("cuda") |
| 106 | + |
| 107 | + engine_cache_dir = "/tmp/your_dir" |
| 108 | + if os.path.exists(engine_cache_dir): |
| 109 | + shutil.rmtree(engine_cache_dir) |
| 110 | + |
| 111 | + engine_cache = MyEngineCache(engine_cache_dir) |
| 112 | + # The 1st iteration is to measure the compilation time without engine caching |
| 113 | + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. |
| 114 | + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. |
| 115 | + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. |
| 116 | + inputs = [torch.rand((100, 3, 224, 224)).to("cuda")] |
| 117 | + results = [] |
| 118 | + for i in range(3): |
| 119 | + # remove timing cache and reset dynamo for engine caching messurement |
| 120 | + if i == 0: |
| 121 | + cache_built_engines = False |
| 122 | + reuse_cached_engines = False |
| 123 | + else: |
| 124 | + cache_built_engines = True |
| 125 | + reuse_cached_engines = True |
| 126 | + |
| 127 | + compiled_model = torch.compile( |
| 128 | + model, |
| 129 | + backend="tensorrt", |
| 130 | + options={ |
| 131 | + "use_python_runtime": True, |
| 132 | + "enabled_precisions": {torch.float}, |
| 133 | + "debug": False, |
| 134 | + "min_block_size": 1, |
| 135 | + "make_refitable": True, |
| 136 | + "cache_built_engines": cache_built_engines, |
| 137 | + "reuse_cached_engines": reuse_cached_engines, |
| 138 | + "custom_engine_cache": engine_cache, # use custom engine cache |
| 139 | + }, |
| 140 | + ) |
| 141 | + results.append(compiled_model(*inputs)) # trigger the compilation |
| 142 | + |
| 143 | + cos_sim = cosine_similarity(results[0], results[1]) |
| 144 | + assertions.assertTrue( |
| 145 | + cos_sim > COSINE_THRESHOLD, |
| 146 | + msg=f"test_torch_compile TRT without engine caching doesn't match with that with engine caching. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 147 | + ) |
| 148 | + |
| 149 | + cos_sim = cosine_similarity(results[1], results[2]) |
| 150 | + assertions.assertTrue( |
| 151 | + cos_sim > COSINE_THRESHOLD, |
| 152 | + msg=f"test_torch_compile TRT with engine caching doesn't match with that cached engine. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", |
| 153 | + ) |
0 commit comments