2 | 2 |
3 | 3 | import collections.abc
4 | 4 | import logging
| 5 | +import platform
5 | 6 | from enum import Enum
6 | 7 | from typing import Any, Callable, List, Optional, Sequence, Set
7 | 8 |
29 | 30 | from torch_tensorrt.dynamo._compiler import (
30 | 31 |     convert_exported_program_to_serialized_trt_engine as dynamo_convert_exported_program_to_serialized_trt_engine,
31 | 32 | )
| 33 | +from torch_tensorrt.dynamo._compiler import (
| 34 | +    cross_compile_for_windows as dynamo_cross_compile_for_windows,
| 35 | +)
| 36 | +from torch_tensorrt.dynamo._compiler import (
| 37 | +    load_cross_compiled_exported_program as dynamo_load_cross_compiled_exported_program,
| 38 | +)
| 39 | +from torch_tensorrt.dynamo._compiler import (
| 40 | +    save_cross_compiled_exported_program as dynamo_save_cross_compiled_exported_program,
| 41 | +)
32 | 42 | from torch_tensorrt.dynamo._tracer import trace as dynamo_trace
33 | 43 |
34 | 44 | logger = logging.getLogger(__name__)
35 | 45 |
36 | | -__all__ = ["compile", "convert_method_to_trt_engine", "save", "load"]
| 46 | +__all__ = [
| 47 | +    "compile",
| 48 | +    "cross_compile_for_windows",
| 49 | +    "load_cross_compiled_exported_program",
| 50 | +    "convert_method_to_trt_engine",
| 51 | +    "save",
| 52 | +    "load",
| 53 | +]
37 | 54 |
38 | 55 |
39 | 56 | def _non_fx_input_interface(
@@ -281,6 +298,105 @@ def compile(
281 | 298 |     raise RuntimeError("Module is an unknown format or the ir requested is unknown")
282 | 299 |
283 | 300 |
| 301 | +def cross_compile_for_windows(
| 302 | +    module: torch.nn.Module,
| 303 | +    file_path: str,
| 304 | +    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
| 305 | +    arg_inputs: Optional[Sequence[Sequence[Any]]] = None,
| 306 | +    kwarg_inputs: Optional[dict[Any, Any]] = None,
| 307 | +    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
| 308 | +    **kwargs: Any,
| 309 | +) -> None:
| 310 | + """Compile a PyTorch module using TensorRT in Linux for Inference in Windows |
| 311 | +
|
| 312 | + Takes an existing PyTorch module and a set of settings to configure the compiler |
| 313 | + and it will convert methods to AOT graphs which call equivalent TensorRT serialized |
| 314 | + engine info into the disk in the specified file_path user provided. |
| 315 | + It will then allow user to load the deserialized model from the disk in Windows. |
| 316 | + Note: the model cross compiled for windows in Linux environmen can only be loaded |
| 317 | + in Windows. |
| 318 | +
|
| 319 | + Argument: |
| 320 | + module (torch.nn.Module): Source module |
| 321 | + file_path (str): the file path to store the serialized module into the disk |
| 322 | +
|
| 323 | + Keyword Arguments: |
| 324 | + inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using |
| 325 | + torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum |
| 326 | + to select device type. :: |
| 327 | +
|
| 328 | + inputs=[ |
| 329 | + torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 |
| 330 | + torch_tensorrt.Input( |
| 331 | + min_shape=(1, 224, 224, 3), |
| 332 | + opt_shape=(1, 512, 512, 3), |
| 333 | + max_shape=(1, 1024, 1024, 3), |
| 334 | + dtype=torch.int32 |
| 335 | + format=torch.channel_last |
| 336 | + ), # Dynamic input shape for input #2 |
| 337 | + torch.randn((1, 3, 224, 244)) # Use an example tensor and let torch_tensorrt infer settings |
| 338 | + ] |
| 339 | + arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. |
| 340 | + kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. |
| 341 | + enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels |
| 342 | + **kwargs: Additional settings for the specific requested strategy (See submodules for more info) |
| 343 | +
|
| 344 | + """ |
| 345 | +
| 346 | +    if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
| 347 | +        raise RuntimeError(
| 348 | +            f"Cross compile for windows is only supported on x86-64 Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
| 349 | +        )
| 350 | +
| 351 | +    if not file_path:
| 352 | +        raise ValueError("File path cannot be empty. Please provide a valid file path")
| 353 | +
| 354 | +    enabled_precisions_set: Set[dtype | torch.dtype] = (
| 355 | +        enabled_precisions
| 356 | +        if enabled_precisions is not None
| 357 | +        else _defaults.ENABLED_PRECISIONS
| 358 | +    )
| 359 | +
| 360 | +    # Prepare torch and torchtrt inputs
| 361 | +    if not arg_inputs and not inputs:
| 362 | +        raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
| 363 | +
| 364 | +    elif arg_inputs and inputs:
| 365 | +        raise AssertionError(
| 366 | +            "'arg_inputs' and 'inputs' should not be used at the same time."
| 367 | +        )
| 368 | +
| 369 | +    arg_inputs = inputs or arg_inputs
| 370 | +
| 371 | +    if kwarg_inputs is None:
| 372 | +        kwarg_inputs = {}
| 373 | +
| 374 | +    from torch_tensorrt.dynamo.utils import prepare_inputs
| 375 | +
| 376 | +    if not isinstance(arg_inputs, collections.abc.Sequence):
| 377 | +        arg_inputs = [arg_inputs]  # type: ignore
| 378 | +
| 379 | +    # Export the module
| 380 | +    torchtrt_arg_inputs = prepare_inputs(arg_inputs)
| 381 | +    torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
| 382 | +
| 383 | +    exp_program = dynamo_trace(
| 384 | +        module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
| 385 | +    )
| 386 | +    logger.debug("successfully exported the module")
| 387 | +
| 388 | +    # Compile and save the module
| 389 | +    trt_gm = dynamo_cross_compile_for_windows(
| 390 | +        exp_program,
| 391 | +        arg_inputs=torchtrt_arg_inputs,
| 392 | +        enabled_precisions=enabled_precisions_set,
| 393 | +        **kwargs,
| 394 | +    )
| 395 | +
| 396 | +    dynamo_save_cross_compiled_exported_program(trt_gm, file_path)
| 397 | +    logger.debug("successfully compiled and saved the module for windows")
| 398 | +
| 399 | +
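As a reference for how the new entry point fits together, here is a minimal Linux-side usage sketch. It assumes `cross_compile_for_windows` is re-exported at the package level (as the `__all__` change above suggests); the model, input shape, and output file name are illustrative, not part of this patch.

```python
import torch
import torch_tensorrt

# Illustrative model; any module that torch.export can trace should work similarly.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=1),
    torch.nn.ReLU(),
).eval().cuda()

# On x86-64 Linux: trace the module, cross-compile it with TensorRT,
# and serialize the result to disk. The file can only be loaded on Windows.
torch_tensorrt.cross_compile_for_windows(
    model,
    file_path="trt_model_windows.ep",  # hypothetical output path
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions={torch.float32},
)
```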
284 | 400 | def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Any:
285 | 401 |     """
286 | 402 |     Returns a boxed model which is the output of torch.compile.
@@ -406,6 +522,19 @@ def convert_method_to_trt_engine(
406 | 522 |     raise RuntimeError("Module is an unknown format or the ir requested is unknown")
407 | 523 |
408 | 524 |
| 525 | +def load_cross_compiled_exported_program(file_path: str = "") -> Any:
| 526 | +    """
| 527 | +    Load an ExportedProgram file on Windows that was previously cross-compiled on Linux
| 528 | +
| 529 | +    Arguments:
| 530 | +        file_path (str): Path to the file on disk
| 531 | +
| 532 | +    Raises:
| 533 | +        ValueError: If the API is not called on Windows, the file does not exist, or the file is not a valid ExportedProgram file
| 534 | +    """
| 535 | +    return dynamo_load_cross_compiled_exported_program(file_path)
| 536 | +
| 537 | +
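A matching Windows-side sketch, assuming the call returns a `torch.export.ExportedProgram` (as the name suggests) and reusing the hypothetical file name from the Linux-side example above:

```python
import torch
import torch_tensorrt

# On Windows: load the program that was cross-compiled and saved on Linux.
trt_exp_program = torch_tensorrt.load_cross_compiled_exported_program(
    "trt_model_windows.ep"  # hypothetical path produced on the Linux side
)

# If an ExportedProgram is returned, .module() gives a callable GraphModule.
trt_module = trt_exp_program.module()
output = trt_module(torch.randn(1, 3, 224, 224).cuda())
```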
409 | 538 | def load(file_path: str = "") -> Any:
410 | 539 |     """
411 | 540 |     Load either a Torchscript model or ExportedProgram.