@@ -2,6 +2,7 @@
 
 import collections.abc
 import logging
+from dataclasses import field
 from typing import Any, List, Optional, Sequence, Set, Tuple, Union
 
 import torch
 
@@ -42,7 +43,7 @@
     CompilationSettings,
     UnsupportedOperatorException,
     convert_module,
-    interpret_module,
+    interpret_module_to_result,
    repair_long_or_double_inputs,
 )
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes, get_decompositions
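A note on the new `from dataclasses import field` import: `field(default_factory=...)` only has its factory semantics inside a `@dataclass` body. Used as a plain function default, as in the new `torch_executed_ops` parameter further down in this diff, the default evaluates to a `dataclasses.Field` sentinel rather than an empty set. A minimal standalone demonstration of the pitfall (illustrative only, not part of the diff):

from dataclasses import dataclass, field

@dataclass
class Settings:
    ops: set = field(default_factory=set)  # inside a dataclass, the factory runs

def convert(ops=field(default_factory=set)):  # as a plain default, it does not
    return ops

print(Settings().ops)   # -> set()
print(type(convert()))  # -> <class 'dataclasses.Field'>

Callers can sidestep this by always passing `torch_executed_ops` explicitly.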
@@ -351,25 +352,108 @@ def convert_method_to_trt_engine(
     module: torch.fx.GraphModule,
     method_name: str = "forward",
     inputs: Optional[Sequence[Input | torch.Tensor]] = None,
-    device: Device = Device._current_device(),
-    disable_tf32: bool = False,
-    sparse_weights: bool = False,
     enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
-    refit: bool = False,
-    debug: bool = False,
-    capability: _enums.EngineCapability = _enums.EngineCapability.default,
-    num_avg_timing_iters: int = 1,
-    workspace_size: int = 0,
-    dla_sram_size: int = 1048576,
-    dla_local_dram_size: int = 1073741824,
-    dla_global_dram_size: int = 536870912,
-    truncate_long_and_double: int = False,
-    calibrator: object = None,
-    allow_shape_tensors: bool = False,
+    debug: bool = DEBUG,
+    workspace_size: int = WORKSPACE_SIZE,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    torch_executed_ops: Set[str] = field(default_factory=set),
+    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
+    use_fast_partitioner: bool = USE_FAST_PARTITIONER,
+    enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
+    device: Device = Device._current_device(),
+    require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
+    disable_tf32: bool = DISABLE_TF32,
+    sparse_weights: bool = SPARSE_WEIGHTS,
+    refit: bool = REFIT,
+    engine_capability: EngineCapability = ENGINE_CAPABILITY,
+    num_avg_timing_iters: int = NUM_AVG_TIMING_ITERS,
+    dla_sram_size: int = DLA_SRAM_SIZE,
+    dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE,
+    dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE,
+    calibrator: object = None,
+    allow_shape_tensors: bool = False,
 ) -> bytes:
+    """Convert a GraphModule module method to a serialized TensorRT engine
+
+    Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
+
+    Arguments:
+        module (torch.fx.GraphModule): Source module
+
+    Keyword Args:
+        inputs (List[Union(torch_tensorrt.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input sizes can be specified as torch sizes, tuples or lists. Dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum
+            to select device type. ::
+
+                inputs=[
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ), # Dynamic input shape for input #2
+                    torch.randn((1, 3, 224, 224)) # Use an example tensor and let torch_tensorrt infer settings
+                ]
+
+        method_name (str): Name of method to convert
+        enabled_precisions (Optional[Set[torch.dtype | _enums.dtype]]): The set of datatypes that TensorRT can use when selecting kernels
+        input_signature (Union(List, Tuple, torch_tensorrt.Input, torch.Tensor)): A formatted collection of input specifications for the module. Input sizes can be specified as torch sizes, tuples or lists. Dtypes can be specified using
+            torch datatypes or torch_tensorrt datatypes and you can use either torch devices or the torch_tensorrt device type enum to select device type. **This API should be considered beta-level stable and may change in the future** ::
+
+                input_signature=([
+                    torch_tensorrt.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1
+                    torch_tensorrt.Input(
+                        min_shape=(1, 224, 224, 3),
+                        opt_shape=(1, 512, 512, 3),
+                        max_shape=(1, 1024, 1024, 3),
+                        dtype=torch.int32,
+                        format=torch.channels_last,
+                    ), # Dynamic input shape for input #2
+                ], torch.randn((1, 3, 224, 224))) # Use an example tensor and let torch_tensorrt infer settings for input #3
+
+        device (Union(torch_tensorrt.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
+
+                device=torch_tensorrt.Device("dla:1", allow_gpu_fallback=True)
+
+        debug (bool): Whether to print out verbose debugging information
+        workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
+        min_block_size (int): Minimum number of operators per TRT engine block
+        torch_executed_ops (Set[str]): Set of operations to run in Torch, regardless of converter coverage
+        pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
+        max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
+        version_compatible (bool): Provide version forward-compatibility for engine plan files
+        optimization_level (Optional[int]): Builder optimization level 0-5; higher levels imply longer build time,
+            searching for more optimization options. TRT defaults to 3
+        use_python_runtime (Optional[bool]): Whether to strictly use the Python runtime or the C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferring the C++ runtime if available), leave the
+            argument as None
+        truncate_long_and_double (bool): Whether to truncate int64/float64 TRT engine inputs or weights to int32/float32
+        use_fast_partitioner (bool): Whether to use the fast or global graph partitioning system
+        enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
+            or only a selected subset of them
+        require_full_compilation (bool): Whether to require that the graph is fully compiled in TensorRT.
+            Only applicable for `ir="dynamo"`; has no effect for the `torch.compile` path
+        disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
+        sparse_weights (bool): Whether to allow the builder to use sparse weights
+        refit (bool): Whether to build a refittable engine
+        engine_capability (trt.EngineCapability): Restrict kernel selection to safe GPU kernels or safe DLA kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        dla_sram_size (int): Fast software-managed RAM used by DLA to communicate within a layer
+        dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
+        dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
+        calibrator (Union(torch_tensorrt._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 calibration
+        allow_shape_tensors (bool): (Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT
+
+    Returns:
+        bytes: Serialized TensorRT engine; can either be saved to a file or deserialized via TensorRT APIs
+    """
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
 
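To make the expanded surface concrete, here is a minimal sketch of calling the new signature. The import path, model, and shapes are illustrative assumptions, not taken from this diff; any module traced to a `torch.fx.GraphModule` would do.

import torch
import torch_tensorrt
from torch_tensorrt.dynamo import convert_method_to_trt_engine  # assumed import path

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval().cuda()
gm = torch.fx.symbolic_trace(model)  # the API expects a torch.fx.GraphModule

serialized_engine = convert_method_to_trt_engine(
    gm,
    method_name="forward",
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions={torch.float32},
    workspace_size=1 << 30,    # 1 GiB
    min_block_size=3,
    torch_executed_ops=set(),  # pass explicitly; see the field() note above
    truncate_long_and_double=True,
)

# Per the documented return type, the result is bytes that can be written to disk
with open("engine.trt", "wb") as f:
    f.write(serialized_engine)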
@@ -403,18 +487,33 @@ def convert_method_to_trt_engine(
     compilation_options = {
         "precision": precision,
         "debug": debug,
-        "device": device,
         "workspace_size": workspace_size,
-        "truncate_long_and_double": truncate_long_and_double,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+        "truncate_long_and_double": truncate_long_and_double,
+        "use_fast_partitioner": use_fast_partitioner,
+        "enable_experimental_decompositions": enable_experimental_decompositions,
+        "device": device,
+        "require_full_compilation": require_full_compilation,
+        "disable_tf32": disable_tf32,
+        "sparse_weights": sparse_weights,
+        "refit": refit,
+        "engine_capability": engine_capability,
+        "num_avg_timing_iters": num_avg_timing_iters,
+        "dla_sram_size": dla_sram_size,
+        "dla_local_dram_size": dla_local_dram_size,
+        "dla_global_dram_size": dla_global_dram_size,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     try:
-        interpreter_result = interpret_module(module, input_list, settings, method_name)
+        interpreter_result = interpret_module_to_result(module, input_list, settings)
     except UnsupportedOperatorException:
         logger.error(
             f"Conversion of module {module} not currently fully supported or convertible!",
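The options dict above is forwarded verbatim into the `CompilationSettings` dataclass, so every key must match a field name, and any field left out keeps its declared default. A small sketch of that pattern; the import path is an assumption, while the field names come from the dict in this diff:

from torch_tensorrt.dynamo._settings import CompilationSettings  # assumed path

# Only override what differs from the defaults; the dataclass fills in the rest
options = {"workspace_size": 1 << 30, "min_block_size": 5}
settings = CompilationSettings(**options)

# Fields absent from the dict fall back to their module-level default constants
print(settings.version_compatible, settings.use_fast_partitioner)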