cross compile for windows #3220

Merged
merged 57 commits into from
Nov 8, 2024
57 commits
458a4d1
skip run_shape_analysis
lanluo-nvidia Oct 6, 2024
2f408f9
test
lanluo-nvidia Oct 6, 2024
1c5e86c
test
lanluo-nvidia Oct 6, 2024
ba487dc
test
lanluo-nvidia Oct 6, 2024
99d2274
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 6, 2024
2b43480
test
lanluo-nvidia Oct 6, 2024
92105ce
cross compile for windows initial check in
lanluo-nvidia Oct 7, 2024
dd24bdf
test
lanluo-nvidia Oct 7, 2024
e6c9fa4
add testcase
lanluo-nvidia Oct 7, 2024
9909eca
add test case
lanluo-nvidia Oct 8, 2024
312a79e
add test case
lanluo-nvidia Oct 8, 2024
b4e02e1
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 11, 2024
3d94f8b
test
lanluo-nvidia Oct 13, 2024
f43669a
Merge branch 'lluo/save_remove_inputs' into lluo/cross_compilation_fo…
lanluo-nvidia Oct 13, 2024
f4d0d27
test
lanluo-nvidia Oct 14, 2024
2a43ca1
add more logs
lanluo-nvidia Oct 14, 2024
76ddd66
test
lanluo-nvidia Oct 15, 2024
28ba6cc
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 15, 2024
b89cbe0
resolve comments
lanluo-nvidia Oct 15, 2024
079c4be
test
lanluo-nvidia Oct 15, 2024
90f1a60
clean up
lanluo-nvidia Oct 16, 2024
c38dc5b
test
lanluo-nvidia Oct 16, 2024
2843d37
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 16, 2024
b06a41b
Merge branch 'lluo/save_remove_inputs' into lluo/cross_compilation_fo…
lanluo-nvidia Oct 16, 2024
3eb48d7
test
lanluo-nvidia Oct 16, 2024
fedc5c2
Merge branch 'lluo/save_remove_inputs' into lluo/cross_compilation_fo…
lanluo-nvidia Oct 16, 2024
50eb0d8
replace dummy inference
lanluo-nvidia Oct 20, 2024
95ed602
test
lanluo-nvidia Oct 20, 2024
120f30d
test
lanluo-nvidia Oct 21, 2024
424cbf7
add run_test_with_dynamic_shape change
lanluo-nvidia Oct 21, 2024
2fc9cef
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 21, 2024
ef54cfc
split the PR, add dummy inference for converter test
lanluo-nvidia Oct 21, 2024
14f5d61
test
lanluo-nvidia Oct 22, 2024
7563959
test
lanluo-nvidia Oct 22, 2024
77355f0
test
lanluo-nvidia Oct 22, 2024
13361fd
add linear lowering meta val
lanluo-nvidia Oct 22, 2024
f0a9fef
add linear_lowering change
lanluo-nvidia Oct 23, 2024
cff64a4
test
lanluo-nvidia Oct 23, 2024
933abac
test
lanluo-nvidia Oct 23, 2024
8417684
resolve comments
lanluo-nvidia Oct 25, 2024
8676f88
test
lanluo-nvidia Oct 25, 2024
785d0b1
change solution: use no_op_placeholder during save and replace it wit…
lanluo-nvidia Oct 26, 2024
b85cb74
Merge branch 'lluo/save_remove_inputs' into lluo/cross_compilation_fo…
lanluo-nvidia Oct 26, 2024
5d594b1
test
lanluo-nvidia Oct 26, 2024
076f47a
resolve comments
lanluo-nvidia Oct 29, 2024
8250179
Merge branch 'main' into lluo/save_remove_inputs
lanluo-nvidia Oct 29, 2024
96e93e4
resolve comments
lanluo-nvidia Oct 29, 2024
7d055bf
Merge branch 'lluo/save_remove_inputs' into lluo/cross_compilation_fo…
lanluo-nvidia Oct 30, 2024
6e214d6
Merge branch 'main' into lluo/cross_compilation_for_windows
lanluo-nvidia Oct 30, 2024
cf0e0ae
resolve comments
lanluo-nvidia Oct 30, 2024
fb7fd1d
resolve comments
lanluo-nvidia Oct 31, 2024
8fc89dc
resolve comments
lanluo-nvidia Oct 31, 2024
63b3376
Merge branch 'main' into lluo/cross_compilation_for_windows
lanluo-nvidia Nov 1, 2024
1a1ba41
merge latest main
lanluo-nvidia Nov 1, 2024
96bc256
fix linting
lanluo-nvidia Nov 1, 2024
eebe1bc
test
lanluo-nvidia Nov 1, 2024
a35a932
resolve comments
lanluo-nvidia Nov 5, 2024
82 changes: 82 additions & 0 deletions examples/dynamo/cross_runtime_compilation_for_windows.py
@@ -0,0 +1,82 @@
"""
.. _resnet_cross_runtime_compilation_for_windows_example:

Cross runtime compilation limitations:
A cross compiled and saved model can only be loaded on Windows; it can no longer be loaded on Linux.
A cross compiled and saved model can only be loaded on a GPU with the same Compute Capability as the GPU of the Linux machine on which it was cross compiled
(for example, a model cross compiled on Linux with a GeForce RTX 4080, which has Compute Capability 8.9,
cannot be loaded on Windows with a GeForce RTX 3080, which has Compute Capability 8.6).

Cross runtime compilation for Windows example
======================================================

Compile and save the ResNet model using Torch-TensorRT on Linux:

python examples/dynamo/cross_runtime_compilation_for_windows.py --path trt_resnet.ep

Load the saved ResNet model on Windows:

python examples/dynamo/cross_runtime_compilation_for_windows.py --path trt_resnet.ep --load True

"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import argparse
import platform

import torch
import torch_tensorrt as torchtrt
import torchvision.models as models

PARSER = argparse.ArgumentParser(
    description="Cross runtime compilation for Windows example: ResNet model"
)
PARSER.add_argument(
    "--load",
    default=False,
    # type=bool would treat any non-empty string (including "False") as True,
    # so the text is parsed explicitly
    type=lambda s: s.strip().lower() in ("true", "1", "yes"),
    required=False,
    help="Load the model in Windows",
)
PARSER.add_argument(
    "--path",
    type=str,
    required=True,
    help="Path to the saved model file",
)

args = PARSER.parse_args()
torch.manual_seed(0)
model = models.resnet18().eval().cuda()
input = torch.rand((1, 3, 224, 224)).to("cuda")
inputs = [input]

# %%
# Cross compile and save on Linux, or load on Windows
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Depending on ``--load``, either cross compile and save the ResNet model for Windows on Linux,
# or load the saved ResNet model on Windows.
if args.load:
    # load the saved model in Windows
    if platform.system() != "Windows" or platform.machine() != "AMD64":
        raise ValueError(
            "a model cross compiled for Windows can only be loaded on a Windows system"
        )
    loaded_model = torchtrt.load_cross_compiled_exported_program(args.path).module()
    print(f"model has been successfully loaded from {args.path}")
    # inference
    trt_output = loaded_model(input)
    print(f"inference result: {trt_output}")
else:
    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
        raise ValueError(
            "a model can only be cross compiled for Windows on a Linux system"
        )
    compile_spec = {
        "debug": True,
        "min_block_size": 1,
    }
    torchtrt.cross_compile_for_windows(
        model, file_path=args.path, inputs=inputs, **compile_spec
    )
    print(
        f"model has been successfully cross compiled and saved in Linux to {args.path}"
    )
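
The two limitations stated in the example's docstring (Windows only, same Compute Capability) can be written down as a small pre-flight check. This is a minimal sketch under stated assumptions: `can_load_cross_compiled` is a hypothetical helper, not part of Torch-TensorRT, and it assumes the caller has recorded the Compute Capability of the Linux build GPU themselves. On a real machine the target capability could be read with `torch.cuda.get_device_capability()`.

```python
from typing import Tuple


def can_load_cross_compiled(
    build_capability: Tuple[int, int],
    target_capability: Tuple[int, int],
    target_system: str,
) -> bool:
    """Hypothetical check mirroring the limitations in the example's docstring."""
    # The saved model is Windows-only and tied to the build GPU's Compute Capability.
    return target_system == "Windows" and build_capability == target_capability


# Built on an RTX 4080 (8.9), loaded on an RTX 3080 (8.6) under Windows: rejected.
print(can_load_cross_compiled((8, 9), (8, 6), "Windows"))  # False
# Same Compute Capability on Windows: accepted.
print(can_load_cross_compiled((8, 9), (8, 9), "Windows"))  # True
# Same Compute Capability, but back on Linux: rejected.
print(can_load_cross_compiled((8, 9), (8, 9), "Linux"))  # False
```

Keeping the check as a pure function of its arguments makes it usable on machines without a GPU, for example in a deployment script that decides which saved artifact to ship.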
131 changes: 130 additions & 1 deletion py/torch_tensorrt/_compile.py
@@ -2,6 +2,7 @@

import collections.abc
import logging
import platform
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set

@@ -29,11 +30,27 @@
from torch_tensorrt.dynamo._compiler import (
convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine,
)
from torch_tensorrt.dynamo._compiler import (
cross_compile_for_windows as dynamo_cross_compile_for_windows,
)
from torch_tensorrt.dynamo._compiler import (
load_cross_compiled_exported_program as dynamo_load_cross_compiled_exported_program,
)
from torch_tensorrt.dynamo._compiler import (
save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program,
)
from torch_tensorrt.dynamo._tracer import trace as dynamo_trace

logger = logging.getLogger(__name__)

-__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]
+__all__ = [
+    "compile",
+    "cross_compile_for_windows",
+    "load_cross_compiled_exported_program",
+    "convert_method_to_trt_engine",
+    "save",
+    "load",
+]


def _non_fx_input_interface(
@@ -281,6 +298,105 @@ def compile(
raise RuntimeError("Module is an unknown format or the ir requested is unknown")


def cross_compile_for_windows(
    module: torch.nn.Module,
    file_path: str,
    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
    kwarg_inputs: Optional[dict[Any, Any]] = None,
    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
    **kwargs: Any,
) -> None:
    """Compile a PyTorch module using TensorRT in Linux for Inference in Windows

    Takes an existing PyTorch module and a set of settings to configure the compiler,
    converts methods to AOT graphs which call equivalent TensorRT engines, and saves
    the serialized engine info to disk at the file_path provided by the user.
    The saved model can then be loaded from disk in Windows.
    Note: a model cross compiled for Windows in a Linux environment can only be loaded
    in Windows.

    Arguments:
        module (torch.nn.Module): Source module
        file_path (str): File path on disk where the serialized module is stored

    Keyword Arguments:
        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** unless ``arg_inputs`` is given. List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
            to select device type. ::

                inputs=[
                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
                    torch_tensorrt.Input(
                        min_shape=(1, 224, 224, 3),
                        opt_shape=(1, 512, 512, 3),
                        max_shape=(1, 1024, 1024, 3),
                        dtype=torch.int32,
                        format=torch.channels_last
                    ), # Dynamic input shape for input #2
                    torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings
                ]
        arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
        kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
        enabled_precisions (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
        **kwargs: Additional settings for the specific requested strategy (See submodules for more info)

    """

    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
        raise RuntimeError(
            f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
        )

    if not file_path:
        raise ValueError("File path cannot be empty. Please provide a valid file path")

    enabled_precisions_set: Set[dtype | torch.dtype] = (
        enabled_precisions
        if enabled_precisions is not None
        else _defaults.ENABLED_PRECISIONS
    )

    # Prepare torch and torchtrt inputs
    if not arg_inputs and not inputs:
        raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")

    elif arg_inputs and inputs:
        raise AssertionError(
            "'arg_inputs' and 'inputs' should not be used at the same time."
        )

    arg_inputs = inputs or arg_inputs

    if kwarg_inputs is None:
        kwarg_inputs = {}

    from torch_tensorrt.dynamo.utils import prepare_inputs

    if not isinstance(arg_inputs, collections.abc.Sequence):
        arg_inputs = [arg_inputs]  # type: ignore

    # Export the module
    torchtrt_arg_inputs = prepare_inputs(arg_inputs)
    torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)

    exp_program = dynamo_trace(
        module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
    )
    logger.debug("successfully exported the module")

    # Compile and save the module
    trt_gm = dynamo_cross_compile_for_windows(
        exp_program,
        arg_inputs=torchtrt_arg_inputs,
        enabled_precisions=enabled_precisions_set,
        **kwargs,
    )

    dynamo_save_cross_compiled_exported_program(trt_gm, file_path)
    logger.debug("successfully compiled and saved the module for windows")
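
The handling of `inputs` and `arg_inputs` in `cross_compile_for_windows` (exactly one of the two must be given, and a lone value is wrapped in a list) can be illustrated in isolation. The sketch below is an assumption-laden restatement, not code from the PR: `normalize_arg_inputs` is a hypothetical name, and plain Python values stand in for `torch_tensorrt.Input` objects and tensors.

```python
import collections.abc
from typing import Any, Optional, Sequence


def normalize_arg_inputs(
    inputs: Optional[Any] = None, arg_inputs: Optional[Any] = None
) -> Sequence[Any]:
    """Hypothetical helper mirroring the mutual-exclusion check in the diff."""
    if not arg_inputs and not inputs:
        raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
    if arg_inputs and inputs:
        raise AssertionError(
            "'arg_inputs' and 'inputs' should not be used at the same time."
        )
    # Whichever of the two was supplied becomes the positional inputs.
    chosen = inputs or arg_inputs
    # A single non-sequence value is wrapped so downstream code always sees a sequence.
    if not isinstance(chosen, collections.abc.Sequence):
        chosen = [chosen]
    return chosen


print(normalize_arg_inputs(inputs=[1, 2]))  # [1, 2]
print(normalize_arg_inputs(arg_inputs=5))  # [5]
```

Because the checks rely on truthiness, an empty list is treated the same as `None`, which matches the behavior of the conditions in the diff.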


def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
"""
Returns a boxed model which is the output of torch.compile.
@@ -406,6 +522,19 @@ def convert_method_to_trt_engine(
raise RuntimeError("Module is an unknown format or the ir requested is unknown")


def load_cross_compiled_exported_program(file_path: str = "") -> Any:
    """
    Load an ExportedProgram file in Windows which was previously cross compiled in Linux

    Arguments:
        file_path (str): Path to file on the disk

    Raises:
        ValueError: If the API is not called on Windows, the file does not exist, or the file is not a valid ExportedProgram file
    """
    return dynamo_load_cross_compiled_exported_program(file_path)
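
The docstring above states that loading raises `ValueError` when the API is not called on Windows. A guard of that kind might look like the sketch below, which follows the platform test used in the example script (`platform.system()` and `platform.machine()`). `ensure_windows_x86_64` is a hypothetical name, and the optional parameters exist only so the check can be exercised on any machine; the actual check inside Torch-TensorRT may differ.

```python
import platform
from typing import Optional


def ensure_windows_x86_64(
    system: Optional[str] = None, machine: Optional[str] = None
) -> None:
    """Raise ValueError unless the given (or detected) platform is x86-64 Windows."""
    # Fall back to the running interpreter's platform when no override is supplied.
    system = platform.system() if system is None else system
    machine = platform.machine() if machine is None else machine
    if system != "Windows" or machine != "AMD64":
        raise ValueError(
            "cross compiled models can only be loaded on x86-64 Windows, "
            f"got {system=}, {machine=}"
        )


try:
    ensure_windows_x86_64(system="Linux", machine="x86_64")
except ValueError as err:
    print(f"rejected: {err}")

# Passes silently for the supported platform.
ensure_windows_x86_64(system="Windows", machine="AMD64")
```

Calling such a guard before deserializing gives a clear error on an unsupported platform instead of a lower-level failure later on.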


def load(file_path: str = "") -> Any:
"""
Load either a Torchscript model or ExportedProgram.
8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/__init__.py
@@ -7,7 +7,13 @@
logger = logging.getLogger(__name__)

 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
-    from ._compiler import compile, convert_exported_program_to_serialized_trt_engine
+    from ._compiler import (
+        compile,
+        convert_exported_program_to_serialized_trt_engine,
+        cross_compile_for_windows,
+        load_cross_compiled_exported_program,
+        save_cross_compiled_exported_program,
+    )
     from ._exporter import export
     from ._refit import refit_module_weights
     from ._settings import CompilationSettings