
Commit 3047783

Yinghai Lu authored and Wei Wei committed
Print shape info in trt profiler (#12)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/12

Now we can provide shape info for the sorted per-layer profile, which will be quite convenient for flushing out low-hanging fruit. Note that in order to get shape info, we need to turn the verbose profile flag to true, which triggers `trt.ProfilingVerbosity.DETAILED` mode.

Reviewed By: jasonjk-park, 842974287

Differential Revision: D34712362

fbshipit-source-id: 82b94ca939a54ff0e1340789da80449915fd0b0e
1 parent ced1978 commit 3047783

File tree

3 files changed: +44, -3 lines changed


fx/lower.py

Lines changed: 7 additions & 1 deletion
@@ -9,6 +9,7 @@
 import torch
 import torch.fx as fx
 import torch.nn as nn
+from fx2trt_oss.fx.observer import Observer
 from fx2trt_oss.tracer.acc_tracer import acc_ops
 from torch.fx.experimental.const_fold import split_const_subgraphs
 from torch.fx.passes.splitter_base import SplitResult
@@ -34,7 +35,6 @@
 from .trt_module import (
     TRTModule,
 )
-from fx2trt_oss.fx.observer import Observer


 logger = logging.getLogger(__name__)
@@ -182,6 +182,8 @@ class LowerSetting:
         modules will not be traced into.

     cuda_graph_batch_size (int): Cuda graph batch size, default to be -1.
+
+    verbose_profile (bool): verbosity of profiler, default to False
     """
     max_batch_size: int = 2048
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -200,6 +202,7 @@ class LowerSetting:
     ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None
     leaf_module_list: Optional[Set[Type[nn.Module]]] = None
     cuda_graph_batch_size: int = -1
+    verbose_profile: bool = False


 def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -283,6 +286,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
             strict_type_constraints=self.lower_setting.strict_type_constraints,
             algorithm_selector=algo_selector,
             timing_cache=cache_data,
+            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
+            if self.lower_setting.verbose_profile
+            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
         )

         # Update timing cache file if needed
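
With this change, detailed per-layer profiling is toggled from LowerSetting rather than by touching the interpreter directly. A minimal usage sketch, assuming the lowering entry point in your checkout accepts a LowerSetting and that the import path fx2trt_oss.fx.lower is valid for your install (both are assumptions for illustration; only the verbose_profile field comes from this diff):

    # Hedged sketch: import path and surrounding lowering call are assumed.
    from fx2trt_oss.fx.lower import LowerSetting

    lower_setting = LowerSetting(
        max_batch_size=64,
        verbose_profile=True,  # engine is built with trt.ProfilingVerbosity.DETAILED
    )
    # ... hand lower_setting to your lowering workflow to produce a TRTModule ...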

fx/tools/trt_profiler_sorted.py

Lines changed: 30 additions & 2 deletions
@@ -1,5 +1,9 @@
 import tensorrt as trt
 import operator
+from typing import Optional, Mapping, List
+import torch
+import json
+from fx2trt_oss.fx import TRTModule


 class SortedTRTProfiler(trt.IProfiler):
@@ -10,8 +14,32 @@ def __init__(self):
     def report_layer_time(self, layer_name: str, ms: int) -> None:
         self.layers[layer_name] = ms

-    def print_sorted_profile(self) -> None:
+    def print_sorted_profile(self, additional_info: Optional[Mapping[str, str]]) -> None:
+        additional_info = {} if additional_info is None else additional_info
         for k, v in sorted(self.layers.items(), key=operator.itemgetter(1)):
-            print(f"{k}: {v}ms")
+            additional_str = additional_info.get(k, "")
+            print(f"{k} {additional_str}: {v}ms")


+def profile_trt_module(
+    name: str, trt_mod: TRTModule, mod_input: List[torch.Tensor]
+) -> None:
+    """
+    Provide per layer timing and shape info
+    """
+    layer_info = json.loads(trt_mod.get_layer_info())  # pyre-ignore[29]
+    shape_map = {}
+    for layer in layer_info["Layers"]:
+        name = layer["Name"]
+        input_str = ", ".join(
+            [str(x.get("Dimensions", "[]")) for x in layer.get("Inputs", [])]
+        )
+        output_str = ", ".join(
+            [str(x.get("Dimensions", "[]")) for x in layer.get("Outputs", [])]
+        )
+        shape_map[name] = f"({input_str}) -> ({output_str})"
+
+    trt_mod.enable_profiling(profiler=SortedTRTProfiler())  # pyre-ignore[29]
+    _ = trt_mod(*mod_input)
+    trt_mod.context.profiler.print_sorted_profile(shape_map)  # pyre-ignore[16]
+    trt_mod.disable_profiling()  # pyre-ignore[29]
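
Taken together, the new profile_trt_module helper parses the engine's layer-info JSON into a name -> "(input dims) -> (output dims)" map, attaches a SortedTRTProfiler, runs one forward pass, and prints the layers sorted by time with their shapes attached. A hedged usage sketch (the module path, input shape, and the trt_mod variable are assumptions; the engine must have been built with detailed profiling verbosity, e.g. verbose_profile=True above):

    import torch
    from fx2trt_oss.fx.tools.trt_profiler_sorted import profile_trt_module

    # Assumed setup: trt_mod is a TRTModule lowered with verbose_profile=True
    # and the sample input matches the engine's expected input signature.
    sample_inputs = [torch.randn(8, 3, 224, 224, device="cuda")]
    profile_trt_module("example_engine", trt_mod, sample_inputs)
    # Prints one line per layer, "<layer> (<input dims>) -> (<output dims>): <ms>ms",
    # sorted ascending so the most expensive layers appear last.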

fx/trt_module.py

Lines changed: 7 additions & 0 deletions
@@ -217,3 +217,10 @@ def disable_profiling(self):
         torch.cuda.synchronize()
         del self.context
         self.context = self.engine.create_execution_context()
+
+    def get_layer_info(self) -> str:
+        """
+        Get layer info of the engine. Only support for TRT > 8.2.
+        """
+        inspector = self.engine.create_engine_inspector()
+        return inspector.get_engine_information(trt.LayerInformationFormat.JSON)
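
get_layer_info exposes TensorRT's engine inspector output, which is what profile_trt_module consumes above: a JSON document with a top-level "Layers" list whose entries carry "Name" plus "Inputs"/"Outputs" with "Dimensions" when the engine was built in detailed mode. A small, hedged sketch of inspecting it directly (assumes trt_mod wraps a TensorRT >= 8.2 engine built with DETAILED profiling verbosity):

    import json

    # Assumption: trt_mod.engine was built with trt.ProfilingVerbosity.DETAILED;
    # otherwise the inspector reports little or no per-layer detail.
    layer_info = json.loads(trt_mod.get_layer_info())
    for layer in layer_info["Layers"]:
        in_dims = [x.get("Dimensions", "[]") for x in layer.get("Inputs", [])]
        out_dims = [x.get("Dimensions", "[]") for x in layer.get("Outputs", [])]
        print(f'{layer["Name"]}: {in_dims} -> {out_dims}')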
