@@ -4,7 +4,6 @@
 import logging
 from typing import Any, Collection, List, Optional, Sequence, Set, Tuple, Union
 
-import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
 from torch.fx.node import Target
@@ -49,7 +48,9 @@
 )
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
+    UnsupportedOperatorException,
     convert_module,
+    interpret_module,
     repair_long_or_double_inputs,
 )
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -447,51 +448,6 @@ def compile_module(
     return partitioned_module
 
 
-def interpreter(
-    module: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    settings: CompilationSettings = CompilationSettings(),
-    name: str = "",
-) -> TRTInterpreterResult:
-    torch_inputs = get_torch_inputs(inputs, settings.device)
-    module_outputs = module(*torch_inputs)
-
-    if not isinstance(module_outputs, (list, tuple)):
-        module_outputs = [module_outputs]
-
-    # Int64 outputs can sometimes be generated from within other operators
-    # such as aten.sum - such outputs can be truncated
-    output_dtypes = []
-    for output in module_outputs:
-        if settings.truncate_long_and_double and output.dtype == torch.float64:
-            output_dtypes.append(torch.float32)
-        elif settings.truncate_long_and_double and output.dtype == torch.int64:
-            output_dtypes.append(torch.int32)
-        else:
-            output_dtypes.append(output.dtype)
-
-    interpreter = TRTInterpreter(
-        module,
-        inputs,
-        logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
-        output_dtypes=output_dtypes,
-        compilation_settings=settings,
-    )
-    interpreter_result = interpreter.run(
-        workspace_size=settings.workspace_size,
-        precision=settings.precision,
-        profiling_verbosity=(
-            trt.ProfilingVerbosity.VERBOSE
-            if settings.debug
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
-        ),
-        max_aux_streams=settings.max_aux_streams,
-        version_compatible=settings.version_compatible,
-        optimization_level=settings.optimization_level,
-    )
-    return interpreter_result
-
-
 def convert_method_to_trt_engine(
     module: torch.fx.GraphModule,
     method_name: str = "forward",
@@ -511,6 +467,9 @@ def convert_method_to_trt_engine(
     truncate_long_and_double: int = False,
     calibrator: object = None,
     allow_shape_tensors: bool = False,
+    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
+    version_compatible: bool = VERSION_COMPATIBLE,
+    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
 ) -> bytes:
     if debug:
         set_log_level(logger.parent, logging.DEBUG)
@@ -548,15 +507,20 @@ def convert_method_to_trt_engine(
         "device": device,
         "workspace_size": workspace_size,
         "truncate_long_and_double": truncate_long_and_double,
-        "max_aux_streams": MAX_AUX_STREAMS,
-        "version_compatible": VERSION_COMPATIBLE,
-        "optimization_level": OPTIMIZATION_LEVEL,
+        "max_aux_streams": max_aux_streams,
+        "version_compatible": version_compatible,
+        "optimization_level": optimization_level,
     }
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
-    interpreter_result = interpreter(module, input_list, settings, method_name)
-
+    try:
+        interpreter_result = interpret_module(module, input_list, settings, method_name)
+    except UnsupportedOperatorException:
+        logger.error(
+            f"Conversion of module {module} not currently fully supported or convertible!",
+            exc_info=True,
+        )
     import io
 
     with io.BytesIO() as engine_bytes:
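
For readers wiring this up downstream, here is a minimal, hedged sketch of how the three newly exposed engine-build options might be passed once this change lands. The import path and the `inputs=` keyword are assumptions not shown in the diff (the function body only references a derived `input_list`); only `module`, `method_name`, and the three new keyword arguments are taken directly from the signature in the patch.

# Hypothetical usage sketch for the updated convert_method_to_trt_engine signature.
# Assumptions (not confirmed by this diff): the import path below and the
# `inputs=` keyword name.
import torch
import torch_tensorrt
from torch_tensorrt.dynamo import convert_method_to_trt_engine  # assumed path


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1.0


# Trace to a torch.fx.GraphModule, matching the `module: torch.fx.GraphModule`
# parameter shown in the patched signature.
gm = torch.fx.symbolic_trace(TinyModel().eval())
example_inputs = [torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)]

serialized_engine: bytes = convert_method_to_trt_engine(
    gm,
    method_name="forward",
    inputs=example_inputs,  # assumed keyword name
    # The three arguments below are the ones this patch exposes; previously they
    # were pinned to the MAX_AUX_STREAMS / VERSION_COMPATIBLE / OPTIMIZATION_LEVEL
    # module-level defaults.
    max_aux_streams=None,    # let TensorRT choose the auxiliary stream count
    version_compatible=False,
    optimization_level=3,    # TensorRT builder optimization level (0-5)
)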