
Commit bc95015

cross compile for windows (#3220)
1 parent bbff652 commit bc95015

File tree

11 files changed: +763 additions, -26 deletions

examples/dynamo/cross_runtime_compilation_for_windows.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+"""
+.. _resnet_cross_runtime_compilation_for_windows_example:
+
+Cross runtime compilation limitations:
+The cross compiled and saved model can only be loaded in Windows; it can no longer be loaded in Linux.
+It can also only be loaded on a GPU with the same compute capability as the Linux GPU on which it was cross compiled
+(for example, if the model was cross compiled in Linux on a GeForce RTX 4080, which has compute capability 8.9,
+it cannot be loaded in Windows on a GeForce RTX 3080, which has compute capability 8.6).
+
+Cross runtime compilation for windows example
+======================================================
+
+Compile and save the Resnet Model using Torch-TensorRT in Linux:
+
+python examples/dynamo/cross_runtime_compilation_for_windows.py --path trt_resnet.ep
+
+Load the Resnet Model saved in Windows:
+
+python examples/dynamo/cross_runtime_compilation_for_windows.py --path trt_resnet.ep --load True
+
+"""
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+import argparse
+import platform
+
+import torch
+import torch_tensorrt as torchtrt
+import torchvision.models as models
+
+PARSER = argparse.ArgumentParser(
+    description="Cross runtime compilation for windows example: Resnet Model"
+)
+PARSER.add_argument(
+    "--load", default=False, type=bool, required=False, help="Load the model in Windows"
+)
+PARSER.add_argument(
+    "--path",
+    type=str,
+    required=True,
+    help="Path to the saved model file",
+)
+
+args = PARSER.parse_args()
+torch.manual_seed(0)
+model = models.resnet18().eval().cuda()
+input = torch.rand((1, 3, 224, 224)).to("cuda")
+inputs = [input]
+
+# %%
+# Depending on the argument, either cross compile and save the resnet model for Windows in Linux
+# or load the saved resnet model in Windows
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+if args.load:
+    # load the saved model in Windows
+    if platform.system() != "Windows" or platform.machine() != "AMD64":
+        raise ValueError(
+            "cross runtime compiled model for windows can only be loaded in Windows system"
+        )
+    loaded_model = torchtrt.load_cross_compiled_exported_program(args.path).module()
+    print(f"model has been successfully loaded from {args.path}")
+    # inference
+    trt_output = loaded_model(input)
+    print(f"inference result: {trt_output}")
+else:
+    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
+        raise ValueError(
+            "cross runtime compiled model for windows can only be compiled in Linux system"
+        )
+    compile_spec = {
+        "debug": True,
+        "min_block_size": 1,
+    }
+    torchtrt.cross_compile_for_windows(
+        model, file_path=args.path, inputs=inputs, **compile_spec
+    )
+    print(
+        f"model has been successfully cross compiled and saved in Linux to {args.path}"
+    )
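
The compute-capability constraint called out in the docstring above can be checked before compiling or loading. A minimal sketch (plain PyTorch, not part of this commit) that reports the capability of the current GPU:

# Not part of the diff: report the compute capability of the current device, which must
# match between the Linux GPU used for cross compilation and the Windows GPU used for loading.
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability of current GPU: {major}.{minor}")  # e.g. 8.9 for a GeForce RTX 4080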

py/torch_tensorrt/_compile.py

Lines changed: 130 additions & 1 deletion
@@ -2,6 +2,7 @@

 import collections.abc
 import logging
+import platform
 from enum import Enum
 from typing import Any, Callable, List, Optional, Sequence, Set

@@ -29,11 +30,27 @@
 from torch_tensorrt.dynamo._compiler import (
     convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine,
 )
+from torch_tensorrt.dynamo._compiler import (
+    cross_compile_for_windows as dynamo_cross_compile_for_windows,
+)
+from torch_tensorrt.dynamo._compiler import (
+    load_cross_compiled_exported_program as dynamo_load_cross_compiled_exported_program,
+)
+from torch_tensorrt.dynamo._compiler import (
+    save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program,
+)
 from torch_tensorrt.dynamo._tracer import trace as dynamo_trace
 
 logger = logging.getLogger(__name__)
 
-__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]
+__all__ = [
+    "compile",
+    "cross_compile_for_windows",
+    "load_cross_compiled_exported_program",
+    "convert_method_to_trt_engine",
+    "save",
+    "load",
+]
 
 
 def _non_fx_input_interface(
@@ -281,6 +298,105 @@ def compile(
     raise RuntimeError("Module is an unknown format or the ir requested is unknown")
 
 
+def cross_compile_for_windows(
+    module: torch.nn.Module,
+    file_path: str,
+    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
+    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
+    kwarg_inputs: Optional[dict[Any, Any]] = None,
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
+    **kwargs: Any,
+) -> None:
+    """Compile a PyTorch module using TensorRT in Linux for inference in Windows
+
+    Takes an existing PyTorch module and a set of settings to configure the compiler,
+    converts the module to an AOT graph that calls an equivalent serialized TensorRT
+    engine, and writes it to disk at the file_path the user provides.
+    The user can then load the serialized model from disk in Windows.
+    Note: a model cross compiled for Windows in a Linux environment can only be loaded
+    in Windows.
+
+    Arguments:
+        module (torch.nn.Module): Source module
+        file_path (str): the file path to store the serialized module on disk
+
+    Keyword Arguments:
+        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                inputs=[
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 224)) # Use an example tensor and let torch_tensorrt infer settings
+                ]
+        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
+        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
+        enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
+        **kwargs: Additional settings for the specific requested strategy (See submodules for more info)
+
+    """
+
+    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
+        raise RuntimeError(
+            f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
+        )
+
+    if not file_path:
+        raise ValueError("File path cannot be empty. Please provide a valid file path")
+
+    enabled_precisions_set: Set[dtype | torch.dtype] = (
+        enabled_precisions
+        if enabled_precisions is not None
+        else _defaults.ENABLED_PRECISIONS
+    )
+
+    # Prepare torch and torchtrt inputs
+    if not arg_inputs and not inputs:
+        raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
+
+    elif arg_inputs and inputs:
+        raise AssertionError(
+            "'arg_inputs' and 'inputs' should not be used at the same time."
+        )
+
+    arg_inputs = inputs or arg_inputs
+
+    if kwarg_inputs is None:
+        kwarg_inputs = {}
+
+    from torch_tensorrt.dynamo.utils import prepare_inputs
+
+    if not isinstance(arg_inputs, collections.abc.Sequence):
+        arg_inputs = [arg_inputs]  # type: ignore
+
+    # Export the module
+    torchtrt_arg_inputs = prepare_inputs(arg_inputs)
+    torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
+
+    exp_program = dynamo_trace(
+        module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
+    )
+    logger.debug("successfully exported the module")
+
+    # Compile and save the module
+    trt_gm = dynamo_cross_compile_for_windows(
+        exp_program,
+        arg_inputs=torchtrt_arg_inputs,
+        enabled_precisions=enabled_precisions_set,
+        **kwargs,
+    )
+
+    dynamo_save_cross_compiled_exported_program(trt_gm, file_path)
+    logger.debug("successfully compiled and saved the module for windows")
+
+
 def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
     """
     Returns a boxed model which is the output of torch.compile.
@@ -406,6 +522,19 @@ def convert_method_to_trt_engine(
     raise RuntimeError("Module is an unknown format or the ir requested is unknown")
 
 
+def load_cross_compiled_exported_program(file_path: str = "") -> Any:
+    """
+    Load an ExportedProgram file in Windows which was previously cross compiled in Linux
+
+    Arguments:
+        file_path (str): Path to file on the disk
+
+    Raises:
+        ValueError: If the API is not called in Windows, the file does not exist, or the file is not a valid ExportedProgram file
+    """
+    return dynamo_load_cross_compiled_exported_program(file_path)
+
+
 def load(file_path: str = "") -> Any:
     """
     Load either a Torchscript model or ExportedProgram.
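
Taken together, the two public entry points added above form a two-step workflow. A minimal end-to-end sketch, not part of this diff: the toy model and the file name model_win.ep are assumptions, while the signatures are the ones added in this file.

import platform

import torch
import torch_tensorrt

model = torch.nn.Linear(8, 4).eval().cuda()  # hypothetical toy model
inputs = [torch.randn(2, 8, device="cuda")]

if platform.system() == "Linux":
    # Step 1 (Linux): cross compile for the Windows runtime and serialize to disk.
    torch_tensorrt.cross_compile_for_windows(
        model, file_path="model_win.ep", inputs=inputs
    )
else:
    # Step 2 (Windows): deserialize the cross compiled program and run inference.
    trt_module = torch_tensorrt.load_cross_compiled_exported_program("model_win.ep").module()
    print(trt_module(*inputs))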

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -7,7 +7,13 @@
 logger = logging.getLogger(__name__)
 
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._compiler import compile, convert_exported_program_to_serialized_trt_engine
+    from ._compiler import (
+        compile,
+        convert_exported_program_to_serialized_trt_engine,
+        cross_compile_for_windows,
+        load_cross_compiled_exported_program,
+        save_cross_compiled_exported_program,
+    )
     from ._exporter import export
     from ._refit import refit_module_weights
     from ._settings import CompilationSettings
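
With these re-exports, the same entry points are also importable from the dynamo namespace; note that, as the wrapper in _compile.py above shows, the dynamo-level cross_compile_for_windows operates on an ExportedProgram rather than an nn.Module. A minimal import sketch (names exactly as re-exported by this diff):

# The dynamo-level variants are the lower-level building blocks used by
# torch_tensorrt.cross_compile_for_windows in _compile.py.
from torch_tensorrt.dynamo import (
    cross_compile_for_windows,
    load_cross_compiled_exported_program,
    save_cross_compiled_exported_program,
)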

0 commit comments
