|
4 | 4 | import logging
|
5 | 5 | from typing import Any, List, Optional, Sequence, Set, Tuple, Union
|
6 | 6 |
|
| 7 | +import tensorrt as trt |
7 | 8 | import torch
|
8 | 9 | import torch_tensorrt
|
9 | 10 | from torch.export import ExportedProgram
|
| 11 | +from torch_tensorrt import _enums |
10 | 12 | from torch_tensorrt._Device import Device
|
11 | 13 | from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum
|
12 | 14 | EngineCapability,
|
|
34 | 36 | convert_module,
|
35 | 37 | repair_long_or_double_inputs,
|
36 | 38 | )
|
| 39 | +from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( |
| 40 | + TRTInterpreter, |
| 41 | + TRTInterpreterResult, |
| 42 | +) |
37 | 43 | from torch_tensorrt.dynamo.lowering import apply_lowering_passes
|
38 | 44 | from torch_tensorrt.dynamo.utils import (
|
39 | 45 | get_torch_inputs,
|
@@ -319,3 +325,122 @@ def compile_module(
|
319 | 325 | settings.use_fast_partitioner = True
|
320 | 326 |
|
321 | 327 | return partitioned_module
|
| 328 | + |
| 329 | + |
def interpreter(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: CompilationSettings = CompilationSettings(),
    name: str = "",
) -> TRTInterpreterResult:
    """Convert ``module`` with ``TRTInterpreter`` and return the interpreter result.

    The module is executed once on tensors produced by ``get_torch_inputs`` so
    that the dtype of each output can be observed and handed to the
    interpreter as ``output_dtypes``.

    Args:
        module: FX graph module to run and convert.
        inputs: Input specifications; used to build the sample tensors and
            passed through to ``TRTInterpreter``.
        settings: Compilation settings. This function reads ``device``,
            ``truncate_long_and_double``, ``debug``, ``workspace_size``,
            ``precision``, ``max_aux_streams``, ``version_compatible`` and
            ``optimization_level`` from it.
        name: Not read by this function body.

    Returns:
        Whatever ``TRTInterpreter.run`` returns for this module.
    """

    def _engine_dtype(dtype: torch.dtype) -> torch.dtype:
        # Int64 outputs can sometimes be generated from within other operators
        # such as aten.sum - such outputs can be truncated
        if settings.truncate_long_and_double:
            if dtype == torch.float64:
                return torch.float32
            if dtype == torch.int64:
                return torch.int32
        return dtype

    sample_tensors = get_torch_inputs(inputs, settings.device)
    outputs = module(*sample_tensors)

    # Normalise a single return value into a one-element list.
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]

    output_dtypes = [_engine_dtype(out.dtype) for out in outputs]

    # Debug mode switches both the TensorRT logger and the profiling detail
    # to their verbose variants.
    if settings.debug:
        log_level = trt.Logger.VERBOSE
        verbosity = trt.ProfilingVerbosity.VERBOSE
    else:
        log_level = trt.Logger.WARNING
        verbosity = trt.ProfilingVerbosity.LAYER_NAMES_ONLY

    trt_interpreter = TRTInterpreter(
        module,
        inputs,
        logger_level=log_level,
        output_dtypes=output_dtypes,
        compilation_settings=settings,
    )
    return trt_interpreter.run(
        workspace_size=settings.workspace_size,
        precision=settings.precision,
        profiling_verbosity=verbosity,
        max_aux_streams=settings.max_aux_streams,
        version_compatible=settings.version_compatible,
        optimization_level=settings.optimization_level,
    )
| 373 | + |
| 374 | + |
def convert_method_to_trt_engine(
    module: torch.fx.GraphModule,
    method_name: str = "forward",
    inputs: Optional[Sequence[Input | torch.Tensor]] = None,
    device: Device = Device._current_device(),
    disable_tf32: bool = False,
    sparse_weights: bool = False,
    enabled_precisions: Optional[Set[torch.dtype | _enums.dtype]] = None,
    refit: bool = False,
    debug: bool = False,
    capability: _enums.EngineCapability = _enums.EngineCapability.default,
    num_avg_timing_iters: int = 1,
    workspace_size: int = 0,
    dla_sram_size: int = 1048576,
    dla_local_dram_size: int = 1073741824,
    dla_global_dram_size: int = 536870912,
    truncate_long_and_double: bool = False,
    calibrator: object = None,
    allow_shape_tensors: bool = False,
) -> bytes:
    """Build a TensorRT engine for ``module`` and return it serialized as bytes.

    Args:
        module: FX graph module to convert.
        method_name: Forwarded to ``interpreter`` as its ``name`` argument.
        inputs: Input specs or tensors; ``None`` is treated as an empty list
            before being passed to ``prepare_inputs``.
        device: Target device, normalised with ``to_torch_tensorrt_device``.
            NOTE(review): the default is evaluated once, when this function
            is defined, not on every call.
        enabled_precisions: Precisions the engine may use. ``None`` means
            ``{torch.float}``. Half precision takes priority over float when
            both are present; an empty set falls back to ``PRECISION``.
        debug: When true, raises the parent logger to ``logging.DEBUG`` and is
            recorded in the compilation settings.
        workspace_size: Stored in the compilation settings.
        truncate_long_and_double: Stored in the compilation settings.
        disable_tf32, sparse_weights, refit, capability, num_avg_timing_iters,
        dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator,
        allow_shape_tensors: Accepted, but not read anywhere in this function
            body.

    Returns:
        The bytes written from ``interpreter_result.engine.serialize()``.

    Raises:
        ValueError: If ``enabled_precisions`` is non-empty and contains
            neither a half nor a float precision.
    """
    import io

    if debug:
        set_log_level(logger.parent, logging.DEBUG)

    input_list = list(inputs) if inputs is not None else []
    # Prepare torch_trt inputs
    input_list = prepare_inputs(input_list)
    device = to_torch_tensorrt_device(device)

    if enabled_precisions is None:
        enabled_precisions = {torch.float}

    # Half precision wins over float when both are requested.
    if (
        torch.float16 in enabled_precisions
        or torch_tensorrt.dtype.half in enabled_precisions
    ):
        precision = torch.float16
    elif (
        torch.float32 in enabled_precisions
        or torch_tensorrt.dtype.float in enabled_precisions
    ):
        precision = torch.float32
    elif not enabled_precisions:
        # Only reachable when the caller explicitly passes an empty collection.
        logger.info("No precision specified, defaulting to %s", PRECISION)
        precision = PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )

    # The last three options are taken from module-level defaults rather than
    # from this function's arguments.
    compilation_options = {
        "precision": precision,
        "debug": debug,
        "device": device,
        "workspace_size": workspace_size,
        "truncate_long_and_double": truncate_long_and_double,
        "max_aux_streams": MAX_AUX_STREAMS,
        "version_compatible": VERSION_COMPATIBLE,
        "optimization_level": OPTIMIZATION_LEVEL,
    }

    settings = CompilationSettings(**compilation_options)
    logger.info("Compilation Settings: %s\n", settings)
    interpreter_result = interpreter(module, input_list, settings, method_name)

    # Copy the serialized engine into an in-memory buffer to obtain ``bytes``.
    with io.BytesIO() as engine_bytes:
        engine_bytes.write(interpreter_result.engine.serialize())
        engine_bytearray = engine_bytes.getvalue()

    return engine_bytearray
0 commit comments