4
4
import logging
5
5
from typing import Any , List , Optional , Sequence , Set , Tuple , Union
6
6
7
- import tensorrt as trt
8
7
import torch
9
8
import torch_tensorrt
10
9
from torch .export import ExportedProgram
41
40
)
42
41
from torch_tensorrt .dynamo .conversion import (
43
42
CompilationSettings ,
43
+ UnsupportedOperatorException ,
44
44
convert_module ,
45
+ interpret_module ,
45
46
repair_long_or_double_inputs ,
46
47
)
47
48
from torch_tensorrt .dynamo .lowering import apply_lowering_passes , get_decompositions
@@ -346,51 +347,6 @@ def compile_module(
346
347
return partitioned_module
347
348
348
349
349
- def interpreter (
350
- module : torch .fx .GraphModule ,
351
- inputs : Sequence [Input ],
352
- settings : CompilationSettings = CompilationSettings (),
353
- name : str = "" ,
354
- ) -> TRTInterpreterResult :
355
- torch_inputs = get_torch_inputs (inputs , settings .device )
356
- module_outputs = module (* torch_inputs )
357
-
358
- if not isinstance (module_outputs , (list , tuple )):
359
- module_outputs = [module_outputs ]
360
-
361
- # Int64 outputs can sometimes be generated from within other operators
362
- # such as aten.sum - such outputs can be truncated
363
- output_dtypes = []
364
- for output in module_outputs :
365
- if settings .truncate_long_and_double and output .dtype == torch .float64 :
366
- output_dtypes .append (torch .float32 )
367
- elif settings .truncate_long_and_double and output .dtype == torch .int64 :
368
- output_dtypes .append (torch .int32 )
369
- else :
370
- output_dtypes .append (output .dtype )
371
-
372
- interpreter = TRTInterpreter (
373
- module ,
374
- inputs ,
375
- logger_level = (trt .Logger .VERBOSE if settings .debug else trt .Logger .WARNING ),
376
- output_dtypes = output_dtypes ,
377
- compilation_settings = settings ,
378
- )
379
- interpreter_result = interpreter .run (
380
- workspace_size = settings .workspace_size ,
381
- precision = settings .precision ,
382
- profiling_verbosity = (
383
- trt .ProfilingVerbosity .VERBOSE
384
- if settings .debug
385
- else trt .ProfilingVerbosity .LAYER_NAMES_ONLY
386
- ),
387
- max_aux_streams = settings .max_aux_streams ,
388
- version_compatible = settings .version_compatible ,
389
- optimization_level = settings .optimization_level ,
390
- )
391
- return interpreter_result
392
-
393
-
394
350
def convert_method_to_trt_engine (
395
351
module : torch .fx .GraphModule ,
396
352
method_name : str = "forward" ,
@@ -410,6 +366,9 @@ def convert_method_to_trt_engine(
410
366
truncate_long_and_double : int = False ,
411
367
calibrator : object = None ,
412
368
allow_shape_tensors : bool = False ,
369
+ max_aux_streams : Optional [int ] = MAX_AUX_STREAMS ,
370
+ version_compatible : bool = VERSION_COMPATIBLE ,
371
+ optimization_level : Optional [int ] = OPTIMIZATION_LEVEL ,
413
372
) -> bytes :
414
373
if debug :
415
374
set_log_level (logger .parent , logging .DEBUG )
@@ -447,15 +406,20 @@ def convert_method_to_trt_engine(
447
406
"device" : device ,
448
407
"workspace_size" : workspace_size ,
449
408
"truncate_long_and_double" : truncate_long_and_double ,
450
- "max_aux_streams" : MAX_AUX_STREAMS ,
451
- "version_compatible" : VERSION_COMPATIBLE ,
452
- "optimization_level" : OPTIMIZATION_LEVEL ,
409
+ "max_aux_streams" : max_aux_streams ,
410
+ "version_compatible" : version_compatible ,
411
+ "optimization_level" : optimization_level ,
453
412
}
454
413
455
414
settings = CompilationSettings (** compilation_options )
456
415
logger .info ("Compilation Settings: %s\n " , settings )
457
- interpreter_result = interpreter (module , input_list , settings , method_name )
458
-
416
+ try :
417
+ interpreter_result = interpret_module (module , input_list , settings , method_name )
418
+ except UnsupportedOperatorException :
419
+ logger .error (
420
+ f"Conversion of module { module } not currently fully supported or convertible!" ,
421
+ exc_info = True ,
422
+ )
459
423
import io
460
424
461
425
with io .BytesIO () as engine_bytes :
0 commit comments