Skip to content

Commit cd86660

Browse files
committed
feat: Add save API for torch-trt compiled models
1 parent ad74a73 commit cd86660

File tree

7 files changed

+104
-54
lines changed

7 files changed

+104
-54
lines changed

.github/scripts/install-torch-tensorrt.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
set -eou pipefail
33
# Source conda so it's available to the script environment
44
source ${BUILD_ENV_FILE}
5-
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision pyyaml
5+
${CONDA_RUN} ${PIP_INSTALL_TORCH} torchvision --extra-index-url https://pypi.python.org/simple
6+
${CONDA_RUN} python -m pip install pyyaml mpmath==1.3.0
67
export TRT_VERSION=$(${CONDA_RUN} python -c "import versions; versions.tensorrt_version()")
78
${CONDA_RUN} python -m pip install /opt/torch-tensorrt-builds/torch_tensorrt*+${CU_VERSION}*.whl tensorrt~=${TRT_VERSION} tensorrt-bindings~=${TRT_VERSION} --extra-index-url=https://pypi.ngc.nvidia.com
89

py/torch_tensorrt/_compile.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
import torch.fx
9+
import torch_tensorrt.dynamo
910
import torch_tensorrt.ts
1011
from torch_tensorrt._enums import dtype
1112
from torch_tensorrt._Input import Input
@@ -29,6 +30,7 @@
2930
__all__ = [
3031
"compile",
3132
"convert_method_to_trt_engine",
33+
"save",
3234
]
3335

3436

@@ -332,3 +334,68 @@ def convert_method_to_trt_engine(
332334
)
333335
else:
334336
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
337+
338+
339+
def save(
340+
module: Any,
341+
file_path: str = "",
342+
*,
343+
output_format: str = "exported_program",
344+
inputs: Optional[Sequence[torch.Tensor]] = None,
345+
retrace: bool = False,
346+
) -> None:
347+
"""
348+
Save the model to disk in the specified output format.
349+
Arguments:
350+
module : Compiled Torch-TensorRT module (Options include torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)
351+
inputs (torch.Tensor): Torch input tensors
352+
"""
353+
module_type = _parse_module_type(module)
354+
accepted_formats = {"exported_program", "torchscript"}
355+
if inputs and not all(isinstance(input, torch.Tensor) for input in inputs):
356+
raise ValueError(
357+
"Not all inputs provided are torch.tensors. Please provide torch.tensors as inputs"
358+
)
359+
if output_format not in accepted_formats:
360+
raise ValueError(
361+
f"Provided output_format {output_format} is not supported. Supported options are exported_program | torchscript"
362+
)
363+
if not file_path:
364+
raise ValueError("File path cannot be empty. Please provide a valid file path")
365+
366+
if module_type == _ModuleType.nn:
367+
raise ValueError(
368+
"Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram."
369+
)
370+
elif module_type == _ModuleType.ts:
371+
if output_format == "exported_program":
372+
raise ValueError(
373+
"Provided model is a torch.jit.ScriptModule but the output_format specified is exported_program. Please verify the output_format"
374+
)
375+
else:
376+
torch.jit.save(module, file_path)
377+
elif module_type == _ModuleType.ep:
378+
if output_format == "torchscript":
379+
raise ValueError(
380+
"Provided model is a torch.export.ExportedProgram but the output_format specified is torchscript. Please verify the output_format"
381+
)
382+
else:
383+
torch.export.save(module, file_path)
384+
elif module_type == _ModuleType.fx:
385+
if not inputs:
386+
raise ValueError(
387+
"Provided model is a torch.fx.GraphModule however the inputs are empty. Please provide valid torch.tensors as inputs to trace and save the model"
388+
)
389+
# The module type is torch.fx.GraphModule
390+
if output_format == "torchscript":
391+
module_ts = torch.jit.trace(module, inputs)
392+
torch.jit.save(module_ts, file_path)
393+
else:
394+
if not retrace:
395+
from torch_tensorrt.dynamo._exporter import export
396+
397+
exp_program = export(module, inputs)
398+
torch.export.save(exp_program, file_path)
399+
else:
400+
exp_program = torch.export.export(module, tuple(inputs), strict=False)
401+
torch.export.save(exp_program, file_path)

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
MIN_BLOCK_SIZE,
3131
NUM_AVG_TIMING_ITERS,
3232
OPTIMIZATION_LEVEL,
33-
OUTPUT_FORMAT,
3433
PASS_THROUGH_BUILD_FAILURES,
3534
PRECISION,
3635
REFIT,
@@ -48,7 +47,6 @@
4847
dryrun_stats_display,
4948
parse_non_trt_nodes,
5049
)
51-
from torch_tensorrt.dynamo._exporter import export
5250
from torch_tensorrt.dynamo.conversion import (
5351
CompilationSettings,
5452
UnsupportedOperatorException,
@@ -102,9 +100,8 @@ def compile(
102100
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
103101
dryrun: bool = DRYRUN,
104102
hardware_compatible: bool = HARDWARE_COMPATIBLE,
105-
output_format: str = OUTPUT_FORMAT,
106103
**kwargs: Any,
107-
) -> Union[ExportedProgram, torch.jit.ScriptModule, torch.fx.GraphModule]:
104+
) -> torch.fx.GraphModule:
108105
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
109106
110107
Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -246,14 +243,12 @@ def compile(
246243
"dla_global_dram_size": dla_global_dram_size,
247244
"dryrun": dryrun,
248245
"hardware_compatible": hardware_compatible,
249-
"output_format": output_format,
250246
}
251247

252248
settings = CompilationSettings(**compilation_options)
253249
logger.info("Compilation Settings: %s\n", settings)
254250
trt_gm = compile_module(gm, inputs, settings)
255-
trt_result = export(trt_gm, torch_inputs, output_format)
256-
return trt_result
251+
return trt_gm
257252

258253

259254
def compile_module(

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
REQUIRE_FULL_COMPILATION = False
2727
DRYRUN = False
2828
HARDWARE_COMPATIBLE = False
29-
OUTPUT_FORMAT = "exported_program"
3029

3130

3231
def default_device() -> Device:

py/torch_tensorrt/dynamo/_exporter.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,16 @@
1818
def export(
1919
gm: torch.fx.GraphModule,
2020
inputs: Sequence[torch.Tensor],
21-
output_format: str,
2221
) -> ExportedProgram:
2322
"""Export the result of TensorRT compilation into the desired output format.
2423
2524
Arguments:
2625
gm (torch.fx.GraphModule): Compiled Torch-TensorRT module, generated by ``torch_tensorrt.dynamo.compile``
2726
inputs (torch.Tensor): Torch input tensors
28-
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
2927
"""
30-
if output_format == "torchscript" or output_format == "ts":
31-
return torch.jit.trace(gm, inputs)
32-
elif output_format == "exported_program" or output_format == "ep":
33-
patched_module = transform(gm, inputs)
34-
exp_program = create_trt_exp_program(patched_module)
35-
return exp_program
36-
elif output_format == "graph_module" or output_format == "fx":
37-
return gm
38-
else:
39-
raise ValueError(
40-
f"Invalid output format {output_format} specified. Supported options include exported_program (or) ep | torchscript (or) ts | graph_module (or) fx"
41-
)
28+
patched_module = transform(gm, inputs)
29+
exp_program = create_trt_exp_program(patched_module)
30+
return exp_program
4231

4332

4433
def transform(

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
MIN_BLOCK_SIZE,
2020
NUM_AVG_TIMING_ITERS,
2121
OPTIMIZATION_LEVEL,
22-
OUTPUT_FORMAT,
2322
PASS_THROUGH_BUILD_FAILURES,
2423
PRECISION,
2524
REFIT,
@@ -71,7 +70,6 @@ class CompilationSettings:
7170
TRT Engines. Prints detailed logs of the graph structure and nature of partitioning. Optionally saves the
7271
ouptut to a file if a string path is specified
7372
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
74-
output_format (str): Output format of the result of TRT compilation. Options include "exported_program" (or) "ep" | "torchscript" (or) "ts" | "graph_module" (or) "fx". Default is "exported_program"
7573
"""
7674

7775
precision: torch.dtype = PRECISION
@@ -99,4 +97,3 @@ class CompilationSettings:
9997
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
10098
dryrun: Union[bool, str] = DRYRUN
10199
hardware_compatible: bool = HARDWARE_COMPATIBLE
102-
output_format: str = OUTPUT_FORMAT

tests/py/dynamo/models/test_export_serde.py

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,18 +42,18 @@ def forward(self, x):
4242
}
4343

4444
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
45-
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
46-
torch.export.save(trt_exp_program, "/tmp/trt.ep")
45+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
46+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
4747
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
48-
48+
deser_trt_module = deser_trt_exp_program.module()
4949
# Check Pyt and TRT exported program outputs
50-
cos_sim = cosine_similarity(model(input), trt_exp_program(input)[0])
50+
cos_sim = cosine_similarity(model(input), trt_module(input)[0])
5151
assertions.assertTrue(
5252
cos_sim > COSINE_THRESHOLD,
5353
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
5454
)
5555
# Check Pyt and deserialized TRT exported program outputs
56-
cos_sim = cosine_similarity(model(input), deser_trt_exp_program(input)[0])
56+
cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
5757
assertions.assertTrue(
5858
cos_sim > COSINE_THRESHOLD,
5959
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
@@ -93,12 +93,13 @@ def forward(self, x):
9393
}
9494

9595
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
96-
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
97-
torch.export.save(trt_exp_program, "/tmp/trt.ep")
96+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
97+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
9898
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
99+
deser_trt_module = deser_trt_exp_program.module()
99100
# Check Pyt and TRT exported program outputs
100101
outputs_pyt = model(input)
101-
outputs_trt = trt_exp_program(input)
102+
outputs_trt = trt_module(input)
102103
for idx in range(len(outputs_pyt)):
103104
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
104105
assertions.assertTrue(
@@ -107,7 +108,7 @@ def forward(self, x):
107108
)
108109

109110
# Check Pyt and deserialized TRT exported program outputs
110-
outputs_trt_deser = deser_trt_exp_program(input)
111+
outputs_trt_deser = deser_trt_module(input)
111112
for idx in range(len(outputs_pyt)):
112113
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
113114
assertions.assertTrue(
@@ -149,12 +150,13 @@ def forward(self, x):
149150
}
150151

151152
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
152-
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
153-
torch.export.save(trt_exp_program, "/tmp/trt.ep")
153+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
154+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
154155
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
156+
deser_trt_module = deser_trt_exp_program.module()
155157
# Check Pyt and TRT exported program outputs
156158
outputs_pyt = model(input)
157-
outputs_trt = trt_exp_program(input)
159+
outputs_trt = trt_module(input)
158160
for idx in range(len(outputs_pyt)):
159161
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
160162
assertions.assertTrue(
@@ -163,7 +165,7 @@ def forward(self, x):
163165
)
164166

165167
# Check Pyt and deserialized TRT exported program outputs
166-
outputs_trt_deser = deser_trt_exp_program(input)
168+
outputs_trt_deser = deser_trt_module(input)
167169
for idx in range(len(outputs_pyt)):
168170
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
169171
assertions.assertTrue(
@@ -207,20 +209,20 @@ def forward(self, x):
207209
}
208210

209211
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
210-
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
211-
torch.export.save(trt_exp_program, "/tmp/trt.ep")
212+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
213+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
212214
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
213-
215+
deser_trt_module = deser_trt_exp_program.module()
214216
outputs_pyt = model(input)
215-
outputs_trt = trt_exp_program(input)
217+
outputs_trt = trt_module(input)
216218
for idx in range(len(outputs_pyt)):
217219
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
218220
assertions.assertTrue(
219221
cos_sim > COSINE_THRESHOLD,
220222
msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
221223
)
222224

223-
outputs_trt_deser = deser_trt_exp_program(input)
225+
outputs_trt_deser = deser_trt_module(input)
224226
for idx in range(len(outputs_pyt)):
225227
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
226228
assertions.assertTrue(
@@ -248,19 +250,19 @@ def test_resnet18(ir):
248250
}
249251

250252
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
251-
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
252-
torch.export.save(trt_exp_program, "/tmp/trt.ep")
253+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
254+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
253255
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
254-
256+
deser_trt_module = deser_trt_exp_program.module()
255257
outputs_pyt = model(input)
256-
outputs_trt = trt_exp_program(input)
258+
outputs_trt = trt_module(input)
257259
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
258260
assertions.assertTrue(
259261
cos_sim > COSINE_THRESHOLD,
260262
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
261263
)
262264

263-
outputs_trt_deser = deser_trt_exp_program(input)
265+
outputs_trt_deser = deser_trt_module(input)
264266

265267
cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
266268
assertions.assertTrue(
@@ -303,12 +305,12 @@ def forward(self, x):
303305
}
304306

305307
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
306-
trt_exp_program = torchtrt.dynamo.compile(exp_program, **compile_spec)
307-
torch.export.save(trt_exp_program, "/tmp/trt.ep")
308+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
309+
torchtrt.save(trt_module, "/tmp/trt.ep", inputs=[input])
308310
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
309-
311+
deser_trt_module = deser_trt_exp_program.module()
310312
outputs_pyt = model(input)
311-
outputs_trt = trt_exp_program(input)
313+
outputs_trt = trt_module(input)
312314

313315
for idx in range(len(outputs_pyt)):
314316
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt[idx])
@@ -317,7 +319,7 @@ def forward(self, x):
317319
msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
318320
)
319321

320-
outputs_trt_deser = deser_trt_exp_program(input)
322+
outputs_trt_deser = deser_trt_module(input)
321323
for idx in range(len(outputs_pyt)):
322324
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
323325
assertions.assertTrue(

0 commit comments

Comments
 (0)