Skip to content

Commit b6a9c71

Browse files
peri044 and zewenli98
authored and committed
feat: Add save API for torch-trt compiled models (#2691)
1 parent ca4488c commit b6a9c71

File tree

6 files changed

+102
-81
lines changed

6 files changed

+102
-81
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,8 @@
77

88
import torch
99
import torch.fx
10+
import torch_tensorrt.dynamo
11+
import torch_tensorrt.ts
1012
from torch_tensorrt._enums import dtype
1113
from torch_tensorrt._features import ENABLED_FEATURES
1214
from torch_tensorrt._Input import Input
@@ -371,27 +373,27 @@ def load(file_path: str = "") -> Any:
371373
try, except
372374
"""
373375
try:
374-
logger.debug(f"Loading the provided file {file_path} using torch.jit.load()")
376+
logger.debug("Loading the provided file using torch.jit.load()")
375377
ts_module = torch.jit.load(file_path)
376378
return ts_module
377379
except Exception:
378-
logger.info(
379-
f"Loading the provided file {file_path} via torch.jit.load() failed with the following error",
380+
logger.debug(
381+
"Loading the provided file via torch.jit.load() failed with the following error",
380382
exc_info=True,
381383
)
382384
pass
383385

384386
try:
385-
logger.debug(f"Loading the provided file {file_path} using torch.export.load()")
387+
logger.debug("Loading the provided file using torch.export.load()")
386388
exp_program = torch.export.load(file_path)
387389
return exp_program
388390
except Exception:
389-
logger.info(
390-
f"Loading the provided file {file_path} via torch.export.load() failed with the following error",
391+
logger.debug(
392+
"Loading the provided file via torch.export.load() failed with the following error",
391393
exc_info=True,
392394
)
393395
raise ValueError(
394-
f"The file {file_path} doesn't correspond to a valid Torchscript module or ExportedProgram. Please verify the file path."
396+
"The file doesn't correspond to a valid Torchscript module or ExportedProgram. Please verify the file path."
395397
)
396398

397399

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 36 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,33 @@
1010
from torch_tensorrt._Device import Device
1111
from torch_tensorrt._enums import EngineCapability, dtype
1212
from torch_tensorrt._Input import Input
13-
from torch_tensorrt.dynamo import _defaults, partitioning
13+
from torch_tensorrt.dynamo import partitioning
14+
from torch_tensorrt.dynamo._defaults import (
15+
DEBUG,
16+
DEVICE,
17+
DISABLE_TF32,
18+
DLA_GLOBAL_DRAM_SIZE,
19+
DLA_LOCAL_DRAM_SIZE,
20+
DLA_SRAM_SIZE,
21+
DRYRUN,
22+
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
23+
ENGINE_CAPABILITY,
24+
HARDWARE_COMPATIBLE,
25+
MAX_AUX_STREAMS,
26+
MIN_BLOCK_SIZE,
27+
NUM_AVG_TIMING_ITERS,
28+
OPTIMIZATION_LEVEL,
29+
PASS_THROUGH_BUILD_FAILURES,
30+
PRECISION,
31+
REFIT,
32+
REQUIRE_FULL_COMPILATION,
33+
SPARSE_WEIGHTS,
34+
TRUNCATE_LONG_AND_DOUBLE,
35+
USE_FAST_PARTITIONER,
36+
USE_PYTHON_RUNTIME,
37+
VERSION_COMPATIBLE,
38+
WORKSPACE_SIZE,
39+
)
1440
from torch_tensorrt.dynamo._DryRunTracker import (
1541
DryRunTracker,
1642
PerSubgraphData,
@@ -63,15 +89,15 @@ def compile(
6389
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
6490
torch_executed_ops: Optional[Collection[Target]] = None,
6591
torch_executed_modules: Optional[List[str]] = None,
66-
pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
67-
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
68-
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
69-
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
70-
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
71-
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
72-
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
73-
dryrun: bool = _defaults.DRYRUN,
74-
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
92+
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
93+
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
94+
version_compatible: bool = VERSION_COMPATIBLE,
95+
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
96+
use_python_runtime: bool = USE_PYTHON_RUNTIME,
97+
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
98+
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
99+
dryrun: bool = DRYRUN,
100+
hardware_compatible: bool = HARDWARE_COMPATIBLE,
75101
**kwargs: Any,
76102
) -> torch.fx.GraphModule:
77103
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
REQUIRE_FULL_COMPILATION = False
2727
DRYRUN = False
2828
HARDWARE_COMPATIBLE = False
29-
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
3029

3130

3231
def default_device() -> Device:

tests/py/dynamo/lowering/test_aten_lowering_passes.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,3 @@
1-
import unittest
2-
31
import torch
42
import torch_tensorrt
53
from torch.testing._internal.common_utils import TestCase, run_tests

tests/py/dynamo/models/test_export_serde.py

Lines changed: 56 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -44,14 +44,19 @@ def forward(self, x):
4444
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
4545
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
4646
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
47-
# TODO: Enable this serialization issues are fixed
48-
# deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
47+
deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
4948
# Check Pyt and TRT exported program outputs
5049
cos_sim = cosine_similarity(model(input), trt_module(input)[0])
5150
assertions.assertTrue(
5251
cos_sim > COSINE_THRESHOLD,
5352
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
5453
)
54+
# Check Pyt and deserialized TRT exported program outputs
55+
cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
56+
assertions.assertTrue(
57+
cos_sim > COSINE_THRESHOLD,
58+
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
59+
)
5560
# TODO: Enable this serialization issues are fixed
5661
# # Check Pyt and deserialized TRT exported program outputs
5762
# cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
@@ -95,9 +100,8 @@ def forward(self, x):
95100

96101
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
97102
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
98-
torchtrt.save(trt_module, "./trt.ep", inputs=[input])
99-
# TODO: Enable this serialization issues are fixed
100-
# deser_trt_module = torchtrt.load("./trt.ep").module()
103+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
104+
deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
101105
# Check Pyt and TRT exported program outputs
102106
outputs_pyt = model(input)
103107
outputs_trt = trt_module(input)
@@ -108,15 +112,14 @@ def forward(self, x):
108112
msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
109113
)
110114

111-
# TODO: Enable this serialization issues are fixed
112-
# # Check Pyt and deserialized TRT exported program outputs
113-
# outputs_trt_deser = deser_trt_module(input)
114-
# for idx in range(len(outputs_pyt)):
115-
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
116-
# assertions.assertTrue(
117-
# cos_sim > COSINE_THRESHOLD,
118-
# msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
119-
# )
115+
# Check Pyt and deserialized TRT exported program outputs
116+
outputs_trt_deser = deser_trt_module(input)
117+
for idx in range(len(outputs_pyt)):
118+
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
119+
assertions.assertTrue(
120+
cos_sim > COSINE_THRESHOLD,
121+
msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
122+
)
120123

121124

122125
@pytest.mark.unit
@@ -152,9 +155,8 @@ def forward(self, x):
152155

153156
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
154157
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
155-
torchtrt.save(trt_module, "./trt.ep", inputs=[input])
156-
# TODO: Enable this serialization issues are fixed
157-
# deser_trt_module = torchtrt.load("./trt.ep").module()
158+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
159+
deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
158160
# Check Pyt and TRT exported program outputs
159161
outputs_pyt = model(input)
160162
outputs_trt = trt_module(input)
@@ -165,15 +167,14 @@ def forward(self, x):
165167
msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
166168
)
167169

168-
# TODO: Enable this serialization issues are fixed
169-
# # Check Pyt and deserialized TRT exported program outputs
170-
# outputs_trt_deser = deser_trt_module(input)
171-
# for idx in range(len(outputs_pyt)):
172-
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
173-
# assertions.assertTrue(
174-
# cos_sim > COSINE_THRESHOLD,
175-
# msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
176-
# )
170+
# Check Pyt and deserialized TRT exported program outputs
171+
outputs_trt_deser = deser_trt_module(input)
172+
for idx in range(len(outputs_pyt)):
173+
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
174+
assertions.assertTrue(
175+
cos_sim > COSINE_THRESHOLD,
176+
msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
177+
)
177178

178179

179180
@pytest.mark.unit
@@ -212,9 +213,8 @@ def forward(self, x):
212213

213214
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
214215
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
215-
torchtrt.save(trt_module, "./trt.ep", inputs=[input])
216-
# TODO: Enable this serialization issues are fixed
217-
# deser_trt_module = torchtrt.load("./trt.ep").module()
216+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
217+
deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
218218
outputs_pyt = model(input)
219219
outputs_trt = trt_module(input)
220220
for idx in range(len(outputs_pyt)):
@@ -224,14 +224,13 @@ def forward(self, x):
224224
msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
225225
)
226226

227-
# TODO: Enable this serialization issues are fixed
228-
# outputs_trt_deser = deser_trt_module(input)
229-
# for idx in range(len(outputs_pyt)):
230-
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
231-
# assertions.assertTrue(
232-
# cos_sim > COSINE_THRESHOLD,
233-
# msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
234-
# )
227+
outputs_trt_deser = deser_trt_module(input)
228+
for idx in range(len(outputs_pyt)):
229+
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
230+
assertions.assertTrue(
231+
cos_sim > COSINE_THRESHOLD,
232+
msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
233+
)
235234

236235

237236
@pytest.mark.unit
@@ -254,9 +253,8 @@ def test_resnet18(ir):
254253

255254
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
256255
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
257-
torchtrt.save(trt_module, "./trt.ep", inputs=[input])
258-
# TODO: Enable this serialization issues are fixed
259-
# deser_trt_module = torchtrt.load("./trt.ep").module()
256+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
257+
deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
260258
outputs_pyt = model(input)
261259
outputs_trt = trt_module(input)
262260
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
@@ -265,13 +263,13 @@ def test_resnet18(ir):
265263
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
266264
)
267265

268-
# TODO: Enable this serialization issues are fixed
269-
# outputs_trt_deser = deser_trt_module(input)
270-
# cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
271-
# assertions.assertTrue(
272-
# cos_sim > COSINE_THRESHOLD,
273-
# msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
274-
# )
266+
outputs_trt_deser = deser_trt_module(input)
267+
268+
cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
269+
assertions.assertTrue(
270+
cos_sim > COSINE_THRESHOLD,
271+
msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
272+
)
275273

276274

277275
@pytest.mark.unit
@@ -310,9 +308,8 @@ def forward(self, x):
310308
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
311309
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
312310

313-
torchtrt.save(trt_module, "./trt.ep", inputs=[input])
314-
# TODO: Enable this serialization issues are fixed
315-
# deser_trt_module = torchtrt.load("./trt.ep").module()
311+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
312+
deser_trt_module = torchtrt.load("/tmp/trt.ep").module()
316313
outputs_pyt = model(input)
317314
outputs_trt = trt_module(input)
318315

@@ -323,14 +320,13 @@ def forward(self, x):
323320
msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
324321
)
325322

326-
# TODO: Enable this serialization issues are fixed
327-
# outputs_trt_deser = deser_trt_module(input)
328-
# for idx in range(len(outputs_pyt)):
329-
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
330-
# assertions.assertTrue(
331-
# cos_sim > COSINE_THRESHOLD,
332-
# msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
333-
# )
323+
outputs_trt_deser = deser_trt_module(input)
324+
for idx in range(len(outputs_pyt)):
325+
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
326+
assertions.assertTrue(
327+
cos_sim > COSINE_THRESHOLD,
328+
msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
329+
)
334330

335331

336332
@pytest.mark.unit
@@ -361,9 +357,9 @@ def forward(self, x):
361357
)
362358
outputs_trt = trt_gm(input)
363359
# Save it as torchscript representation
364-
torchtrt.save(trt_gm, "./trt.ts", output_format="torchscript", inputs=[input])
360+
torchtrt.save(trt_gm, "/tmp/trt.ts", output_format="torchscript", inputs=[input])
365361

366-
trt_ts_module = torchtrt.load("./trt.ts")
362+
trt_ts_module = torchtrt.load("/tmp/trt.ts")
367363
outputs_trt_deser = trt_ts_module(input)
368364

369365
cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser)

tests/py/dynamo/models/test_models_export.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,7 +129,7 @@ def test_bert_base_uncased(ir):
129129
"enabled_precisions": {torch.float},
130130
"truncate_long_and_double": True,
131131
"ir": ir,
132-
"min_block_size": 10,
132+
"min_block_size": 15,
133133
}
134134
trt_mod = torchtrt.compile(model, **compile_spec)
135135
model_outputs = model(input, input2)

0 commit comments

Comments
 (0)