4
4
import logging
5
5
from typing import Any , List , Optional , Sequence , Set , Tuple , Union
6
6
7
- import tensorrt as trt
8
7
import torch
9
8
import torch_tensorrt
10
9
from torch .export import ExportedProgram
33
32
)
34
33
from torch_tensorrt .dynamo .conversion import (
35
34
CompilationSettings ,
35
+ UnsupportedOperatorException ,
36
36
convert_module ,
37
+ interpret_module ,
37
38
repair_long_or_double_inputs ,
38
39
)
39
- from torch_tensorrt .dynamo .conversion ._TRTInterpreter import (
40
- TRTInterpreter ,
41
- TRTInterpreterResult ,
42
- )
43
40
from torch_tensorrt .dynamo .lowering import apply_lowering_passes
44
41
from torch_tensorrt .dynamo .utils import (
45
42
get_torch_inputs ,
@@ -327,51 +324,6 @@ def compile_module(
327
324
return partitioned_module
328
325
329
326
330
def interpreter(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: Optional[CompilationSettings] = None,
    name: str = "",
) -> TRTInterpreterResult:
    """Run the TRTInterpreter over an FX graph module and return its result.

    The module is first executed once on the torch inputs produced by
    ``get_torch_inputs`` so that the dtypes of its outputs can be recorded
    and passed to the interpreter.

    Args:
        module: FX graph module to interpret.
        inputs: Input specifications handed to ``get_torch_inputs`` and to
            ``TRTInterpreter``.
        settings: Compilation settings. When omitted, a fresh
            ``CompilationSettings()`` is created for this call.
        name: Accepted for interface compatibility; not read by this function.

    Returns:
        The value returned by ``TRTInterpreter.run``.
    """
    # A ``CompilationSettings()`` default in the signature would be evaluated
    # once at definition time and shared by every call; build it per call.
    if settings is None:
        settings = CompilationSettings()

    torch_inputs = get_torch_inputs(inputs, settings.device)
    module_outputs = module(*torch_inputs)

    if not isinstance(module_outputs, (list, tuple)):
        module_outputs = [module_outputs]

    # Int64 outputs can sometimes be generated from within other operators
    # such as aten.sum - such outputs can be truncated
    truncated_dtypes = {
        torch.float64: torch.float32,
        torch.int64: torch.int32,
    }
    if settings.truncate_long_and_double:
        output_dtypes = [
            truncated_dtypes.get(output.dtype, output.dtype)
            for output in module_outputs
        ]
    else:
        output_dtypes = [output.dtype for output in module_outputs]

    # Debug mode raises both the TensorRT logger level and profiling detail.
    if settings.debug:
        logger_level = trt.Logger.VERBOSE
        profiling_verbosity = trt.ProfilingVerbosity.VERBOSE
    else:
        logger_level = trt.Logger.WARNING
        profiling_verbosity = trt.ProfilingVerbosity.LAYER_NAMES_ONLY

    # Named ``trt_interpreter`` so the local does not shadow this function.
    trt_interpreter = TRTInterpreter(
        module,
        inputs,
        logger_level=logger_level,
        output_dtypes=output_dtypes,
        compilation_settings=settings,
    )
    return trt_interpreter.run(
        workspace_size=settings.workspace_size,
        precision=settings.precision,
        profiling_verbosity=profiling_verbosity,
        max_aux_streams=settings.max_aux_streams,
        version_compatible=settings.version_compatible,
        optimization_level=settings.optimization_level,
    )
373
-
374
-
375
327
def convert_method_to_trt_engine (
376
328
module : torch .fx .GraphModule ,
377
329
method_name : str = "forward" ,
@@ -391,6 +343,9 @@ def convert_method_to_trt_engine(
391
343
truncate_long_and_double : int = False ,
392
344
calibrator : object = None ,
393
345
allow_shape_tensors : bool = False ,
346
+ max_aux_streams : Optional [int ] = MAX_AUX_STREAMS ,
347
+ version_compatible : bool = VERSION_COMPATIBLE ,
348
+ optimization_level : Optional [int ] = OPTIMIZATION_LEVEL ,
394
349
) -> bytes :
395
350
if debug :
396
351
set_log_level (logger .parent , logging .DEBUG )
@@ -428,15 +383,20 @@ def convert_method_to_trt_engine(
428
383
"device" : device ,
429
384
"workspace_size" : workspace_size ,
430
385
"truncate_long_and_double" : truncate_long_and_double ,
431
- "max_aux_streams" : MAX_AUX_STREAMS ,
432
- "version_compatible" : VERSION_COMPATIBLE ,
433
- "optimization_level" : OPTIMIZATION_LEVEL ,
386
+ "max_aux_streams" : max_aux_streams ,
387
+ "version_compatible" : version_compatible ,
388
+ "optimization_level" : optimization_level ,
434
389
}
435
390
436
391
settings = CompilationSettings (** compilation_options )
437
392
logger .info ("Compilation Settings: %s\n " , settings )
438
- interpreter_result = interpreter (module , input_list , settings , method_name )
439
-
393
+ try :
394
+ interpreter_result = interpret_module (module , input_list , settings , method_name )
395
+ except UnsupportedOperatorException :
396
+ logger .error (
397
+ f"Conversion of module { module } not currently fully supported or convertible!" ,
398
+ exc_info = True ,
399
+ )
440
400
import io
441
401
442
402
with io .BytesIO () as engine_bytes :
0 commit comments