1
1
import logging
2
2
import math
3
+ import operator
4
+ import os
3
5
from dataclasses import dataclass , field
4
- from typing import Any , Dict , List
6
+ from typing import Any , Dict , List , Union
5
7
8
+ import torch
6
9
from torch_tensorrt .dynamo ._settings import CompilationSettings
10
+ from torch_tensorrt .dynamo .conversion ._ConverterRegistry import ConverterRegistry
11
+ from torch_tensorrt .dynamo .conversion .converter_utils import get_node_name
7
12
8
13
logger = logging .getLogger (__name__ )
9
14
@@ -44,6 +49,7 @@ class DryRunTracker:
44
49
tensorrt_graph_count (int): Number of TensorRT engines to be generated
45
50
compilation_settings (CompilationSettings): User Compilation Settings
46
51
unsupported_ops (Dict[str, int]): Set of operators not supported in TRT
52
+ to_run_in_torch (List[str]): Set of nodes to run in Torch
47
53
"""
48
54
49
55
total_ops_in_graph : int = 0
@@ -58,9 +64,12 @@ class DryRunTracker:
58
64
default_factory = CompilationSettings
59
65
)
60
66
unsupported_ops : Dict [str , int ] = field (default_factory = dict )
67
+ to_run_in_torch : List [str ] = field (default_factory = list )
61
68
62
69
63
- def dryrun_stats_display (dryrun_tracker : DryRunTracker , dryrun_enabled : bool ) -> None :
70
+ def dryrun_stats_display (
71
+ dryrun_tracker : DryRunTracker , dryrun_enabled : Union [bool , str ]
72
+ ) -> None :
64
73
"""Displays statistics about the dryrun either to debug logs or stdout"""
65
74
formatted_stats = "\n "
66
75
@@ -71,7 +80,19 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
71
80
f"of which { dryrun_tracker .supported_ops_in_graph } operators are supported, "
72
81
f"{ round (dryrun_tracker .supported_ops_in_graph * 100 / dryrun_tracker .total_ops_in_graph , 2 )} % coverage\n \n "
73
82
)
74
- formatted_stats += f"The following ops are currently unsupported and set to run in Torch: { dryrun_tracker .unsupported_ops } \n \n "
83
+ if dryrun_tracker .unsupported_ops :
84
+ parsed_ops = "\n " .join (
85
+ [f"{ str (k )} : { str (v )} " for k , v in dryrun_tracker .unsupported_ops .items ()]
86
+ )
87
+ formatted_stats += f"The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:\n { parsed_ops } \n \n "
88
+
89
+ if dryrun_tracker .to_run_in_torch :
90
+ formatted_nodes = "\n " .join (dryrun_tracker .to_run_in_torch )
91
+ formatted_stats += (
92
+ f"The following nodes are currently set to run in Torch:\n { formatted_nodes } \n "
93
+ "Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner\n \n "
94
+ )
95
+
75
96
formatted_stats += f"Compiled with: { dryrun_tracker .compilation_settings } \n \n "
76
97
77
98
assert len (dryrun_tracker .per_subgraph_data ) == dryrun_tracker .tensorrt_graph_count
@@ -184,8 +205,17 @@ def dryrun_stats_display(dryrun_tracker: DryRunTracker, dryrun_enabled: bool) ->
184
205
)
185
206
186
207
# If user specified "dryrun=True", print to stdout, else debug
208
+ # If user specified a filepath, save the output to the path as well
187
209
if dryrun_enabled :
188
210
print (formatted_stats )
211
+ if isinstance (dryrun_enabled , str ):
212
+ if os .path .exists (dryrun_enabled ):
213
+ logger .warning (
214
+ f"File already exists at path { dryrun_enabled } , not saving dryrun output"
215
+ )
216
+ else :
217
+ with open (dryrun_enabled , "w+" ) as f :
218
+ f .write (formatted_stats )
189
219
else :
190
220
logger .debug (formatted_stats )
191
221
@@ -225,3 +255,23 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
225
255
)
226
256
227
257
return input_formatter_helper (shapes , dtypes )[:- 2 ]
258
+
259
+
260
def parse_non_trt_nodes(graph_module: torch.fx.GraphModule) -> List[str]:
    """Collect descriptions of the computational nodes in a GraphModule.

    Walks the FX graph and, for every call_function / call_method node
    (skipping ``operator.getitem``, which is a Tensor-collection op rather
    than real computation), builds a human-readable line naming the target
    op and its layer location.

    Args:
        graph_module: FX GraphModule whose nodes will be inspected.

    Returns:
        A list of formatted strings, one per qualifying node.
    """
    node_descriptions: List[str] = []

    for fx_node in graph_module.graph.nodes:
        # Only real computation nodes are of interest here.
        if fx_node.op not in ("call_function", "call_method"):
            continue
        # getitem nodes are excluded since they are a Tensor-collection op.
        if fx_node.target == operator.getitem:
            continue

        node_descriptions.append(
            f"Node: {ConverterRegistry.qualified_name_or_str(fx_node.target)}, "
            f"with layer location: {get_node_name(fx_node)}"
        )

    return node_descriptions
0 commit comments