Skip to content

Commit 74187d8

Browse files
Yinghai Lu authored and Wei Wei committed
Make mts benchmark profiling easier (#9)
Summary: Pull Request resolved: https://github.com/pytorch/fx2trt/pull/9 By default, we can now just profile the lowered TRT module and print its per-layer cost in sorted form. Also renamed `inline_cvr_7x_gpu_benchmark.py` to `mts_gpu_benchmark.py`, as it has become quite generic now. Reviewed By: wushirong Differential Revision: D34690159 fbshipit-source-id: f30c18d2e139d934392fd7e253fd774f36a8ca11
1 parent 43bd93e commit 74187d8

File tree

4 files changed

+21
-9
lines changed

4 files changed

+21
-9
lines changed

fx/lower.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
# >>> # print_module_and_input will be called right after the fuse passes
5454
# >>> lower(module, sample_input)
5555

56-
# Observer for the model after the fuse passes.
56+
# Observer for the model after the fuse passes.
5757
FUSE_PASSES_POST_OBSERVER: Observer[
5858
Callable[[nn.Module, Input], None]
5959
] = Observer("FUSE_PASSES_POST_OBSERVER")
@@ -66,7 +66,7 @@
6666
# Observer for the TRT split submodules after lowering
6767
LOWER_SPLIT_POST_OBSERVER: Observer[
6868
Callable[[str, nn.Module, Input], None]
69-
] = Observer("LOWER_SPLIT_PRE_OBSERVER")
69+
] = Observer("LOWER_SPLIT_POST_OBSERVER")
7070
# ----------------------------------------------------------------------
7171

7272

fx/tools/engine_layer_visualize.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# (c) Facebook, Inc. and its affiliates. Confidential and proprietary.
2-
31
import argparse
42
import re
53
from typing import NamedTuple, List, Optional, Dict, Any, Tuple
@@ -17,9 +15,6 @@
1715
1816
Usage:
1917
python fx2trt_oss.fx/tools/engine_layer_visualize.py --log_file aaa --profile_file bbb
20-
21-
Usage(Facebook):
22-
buck run //caffe2:trt_engine_layer_visualize -- --log_file aaa --profile_file bbb
2318
"""
2419

2520

fx/tools/trt_profiler_sorted.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import tensorrt as trt
2+
import operator
3+
4+
5+
class SortedTRTProfiler(trt.IProfiler):
    """TensorRT profiler that collects per-layer times and prints them sorted.

    Each reported layer time is stored in ``self.layers`` keyed by layer
    name. A later report for the same layer name replaces the earlier value,
    so the mapping holds the most recently reported time per layer.
    """

    def __init__(self):
        super().__init__()
        # layer name -> last reported execution time in milliseconds
        self.layers = {}

    def report_layer_time(self, layer_name: str, ms: float) -> None:
        """Record the time reported for one layer.

        Args:
            layer_name: Name of the layer being reported.
            ms: Time spent in the layer, in milliseconds. The TensorRT
                ``IProfiler`` interface passes this as a float, not an int.
        """
        self.layers[layer_name] = ms

    def print_sorted_profile(self) -> None:
        """Print every recorded layer as ``name: <time>ms``, cheapest first."""
        for name, elapsed in sorted(self.layers.items(), key=operator.itemgetter(1)):
            print(f"{name}: {elapsed}ms")
16+
17+

fx/trt_module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,15 +198,15 @@ def forward(self, *inputs):
198198

199199
return tuple(outputs)
200200

201-
def enable_profiling(self):
201+
def enable_profiling(self, profiler: "trt.IProfiler"=None):
202202
"""
203203
Enable TensorRT profiling. After calling this function, TensorRT will report
204204
time spent on each layer in stdout for each forward run.
205205
"""
206206
self._check_initialized()
207207

208208
if not self.context.profiler:
209-
self.context.profiler = trt.Profiler()
209+
self.context.profiler = trt.Profiler() if profiler is None else profiler
210210

211211
def disable_profiling(self):
212212
"""

0 commit comments

Comments
 (0)