|
4 | 4 | import logging
|
5 | 5 | from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
|
6 | 6 |
|
| 7 | +import tensorrt as trt |
7 | 8 | import torch
|
8 | 9 | from torch.export import ExportedProgram
|
9 | 10 | from torch.fx.node import Target
|
| 11 | +from torch_tensorrt import _enums |
10 | 12 | from torch_tensorrt._Device import Device
|
11 | 13 | from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum
|
12 | 14 | EngineCapability,
|
@@ -443,3 +445,122 @@ def compile_module(
|
443 | 445 | dryrun_stats_display(dryrun_tracker, settings.dryrun)
|
444 | 446 |
|
445 | 447 | return partitioned_module
|
| 448 | + |
| 449 | + |
def interpreter(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: CompilationSettings = CompilationSettings(),
    name: str = "",
) -> TRTInterpreterResult:
    """Run ``TRTInterpreter`` over ``module`` and return its result.

    The module is first called once on the tensors produced by
    ``get_torch_inputs(inputs, settings.device)`` so that the dtype of every
    output is known before the interpreter is constructed.

    Args:
        module: FX graph module to interpret.
        inputs: Input specifications; used both to build the sample tensors
            and as the interpreter's input specs.
        settings: Compilation settings read for device, truncation, debug,
            workspace size, precision, aux streams, version compatibility
            and optimization level.
        name: Not referenced by this function.

    Returns:
        Whatever ``TRTInterpreter.run`` returns for the given settings.
    """
    sample_outputs = module(*get_torch_inputs(inputs, settings.device))
    if not isinstance(sample_outputs, (list, tuple)):
        sample_outputs = [sample_outputs]

    # Int64 outputs can sometimes be generated from within other operators
    # such as aten.sum - such outputs can be truncated (as can float64 ones)
    narrowed = {torch.float64: torch.float32, torch.int64: torch.int32}
    output_dtypes = [
        (
            narrowed.get(out.dtype, out.dtype)
            if settings.truncate_long_and_double
            else out.dtype
        )
        for out in sample_outputs
    ]

    # Debug mode switches both the TensorRT logger and the profiler to verbose.
    if settings.debug:
        logger_level = trt.Logger.VERBOSE
        profiling_verbosity = trt.ProfilingVerbosity.VERBOSE
    else:
        logger_level = trt.Logger.WARNING
        profiling_verbosity = trt.ProfilingVerbosity.LAYER_NAMES_ONLY

    trt_interpreter = TRTInterpreter(
        module,
        inputs,
        logger_level=logger_level,
        output_dtypes=output_dtypes,
        compilation_settings=settings,
    )
    return trt_interpreter.run(
        workspace_size=settings.workspace_size,
        precision=settings.precision,
        profiling_verbosity=profiling_verbosity,
        max_aux_streams=settings.max_aux_streams,
        version_compatible=settings.version_compatible,
        optimization_level=settings.optimization_level,
    )
| 493 | + |
| 494 | + |
def convert_method_to_trt_engine(
    module: torch.fx.GraphModule,
    method_name: str = "forward",
    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
    device: Device = Device._current_device(),
    disable_tf32: bool = False,
    sparse_weights: bool = False,
    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
    refit: bool = False,
    debug: bool = False,
    capability: _enums.EngineCapability = _enums.EngineCapability.default,
    num_avg_timing_iters: int = 1,
    workspace_size: int = 0,
    dla_sram_size: int = 1048576,
    dla_local_dram_size: int = 1073741824,
    dla_global_dram_size: int = 536870912,
    truncate_long_and_double: bool = False,
    calibrator: object = None,
    allow_shape_tensors: bool = False,
) -> bytes:
    """Convert an FX graph module into a serialized TensorRT engine.

    Args:
        module: FX graph module to convert.
        method_name: Forwarded to ``interpreter`` as its ``name`` argument.
        inputs: Input specs or tensors; ``None`` is treated as no inputs.
            They are normalized with ``prepare_inputs``.
        device: Target device, normalized with ``to_torch_tensorrt_device``.
        enabled_precisions: Allowed precisions. ``None`` means
            ``{torch.float}``. Half precision wins over float when both are
            present; an empty set falls back to ``PRECISION``.
        debug: Raises the parent logger to ``DEBUG`` and is passed through to
            the compilation settings.
        workspace_size: Passed through to the compilation settings.
        truncate_long_and_double: Passed through to the compilation settings.
        disable_tf32, sparse_weights, refit, capability,
        num_avg_timing_iters, dla_sram_size, dla_local_dram_size,
        dla_global_dram_size, calibrator, allow_shape_tensors: Accepted but
            not referenced by this function; they currently have no effect
            here.

    Returns:
        The serialized engine as ``bytes``.

    Raises:
        ValueError: If ``enabled_precisions`` is non-empty and contains
            neither a half nor a float precision.
    """
    if debug:
        set_log_level(logger.parent, logging.DEBUG)

    # Prepare torch_trt inputs
    input_list = prepare_inputs(list(inputs) if inputs is not None else [])
    device = to_torch_tensorrt_device(device)

    if enabled_precisions is None:
        enabled_precisions = {torch.float}

    # ``_enums.dtype`` is the same enum the signature is annotated with; using
    # it avoids depending on a bare ``torch_tensorrt`` name being bound here.
    if (
        torch.float16 in enabled_precisions
        or _enums.dtype.half in enabled_precisions
    ):
        precision = torch.float16
    elif (
        torch.float32 in enabled_precisions
        or _enums.dtype.float in enabled_precisions
    ):
        precision = torch.float32
    elif not enabled_precisions:
        logger.info("No precision specified, defaulting to %s", PRECISION)
        precision = PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )

    settings = CompilationSettings(
        precision=precision,
        debug=debug,
        device=device,
        workspace_size=workspace_size,
        truncate_long_and_double=truncate_long_and_double,
        max_aux_streams=MAX_AUX_STREAMS,
        version_compatible=VERSION_COMPATIBLE,
        optimization_level=OPTIMIZATION_LEVEL,
    )
    logger.info("Compilation Settings: %s\n", settings)
    interpreter_result = interpreter(module, input_list, settings, method_name)

    # ``serialize()`` yields a bytes-like buffer; copy it straight into an
    # immutable ``bytes`` object instead of round-tripping through BytesIO.
    return bytes(interpreter_result.engine.serialize())
0 commit comments