Skip to content

Commit 4cac480

Browse files
leleamolVikas-kum
andauthored
Adding script to read the events from trace file (aws#246)
* adding script to read the events from tensorboard trace. * Refactored code Introduced base event parser and the child classes TFEventProfiler and SMTFEventProfiler Co-authored-by: Vikas Kumar <[email protected]>
1 parent 80153e7 commit 4cac480

File tree

8 files changed

+3366
-0
lines changed

8 files changed

+3366
-0
lines changed

config/tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ export REPORT_DIR=$OUT_DIR/pytest_reports
5151
python -m pytest ${code_coverage_smdebug:+--cov=./ --cov-append} -v -W=ignore --durations=50 --html=$REPORT_DIR/report_analysis.html --self-contained-html tests/analysis
5252

5353
run_for_framework core
54+
run_for_framework profiler
5455

5556
if [ "$run_pytest_xgboost" = "enable" ] ; then
5657
run_for_framework xgboost

smdebug/profiler/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Local
2+
from .tf_profiler_parser import SMTFProfilerEvents, TFProfilerEvents
3+
from .trace_event_file_parser import TraceEvent, TraceEventParser
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Standard Library
2+
import json
3+
4+
# First Party
5+
from smdebug.profiler.trace_event_file_parser import ProcessInfo, TraceEventParser
6+
7+
8+
class SMTFProfilerEvents(TraceEventParser):
9+
def __init__(self, trace_file):
10+
self._trace_json_file = trace_file
11+
super().__init__()
12+
self.read_trace_file()
13+
14+
def _populate_start_time(self, event):
15+
event_args = event["args"] if "args" in event else None
16+
if self._start_time_known is False:
17+
if event_args is None:
18+
return
19+
if "start_time_since_epoch_in_micros" in event_args:
20+
self._start_timestamp = event_args["start_time_since_epoch_in_micros"]
21+
self._start_time_known = True
22+
self.logger.info(f"Start time for events in uSeconds = {self._start_timestamp}")
23+
24+
# TODO implementation of below would be changed to support streaming file and incomplete json file
25+
def read_trace_file(self):
26+
try:
27+
with open(self._trace_json_file) as json_data:
28+
trace_json_data = json.load(json_data)
29+
except Exception as e:
30+
self.logger.error(
31+
f"Can't open TF trace file {self._trace_json_file}: Exception {str(e)}"
32+
)
33+
return
34+
35+
for event in trace_json_data:
36+
self._read_event(event)
37+
38+
39+
class TFProfilerEvents(TraceEventParser):
40+
def __init__(self, trace_file):
41+
self._trace_json_file = trace_file
42+
super().__init__()
43+
self.read_trace_file()
44+
45+
def _populate_thread_info_for_metaevent(self, event):
46+
if event["name"] == "thread_name":
47+
name = event["args"]["name"]
48+
t_id = event["tid"]
49+
pid = event["pid"]
50+
if pid not in self._processes:
51+
self.logger.warn(
52+
f"Did not find matching process for pid {pid}. Creating a process with name 'Unknown'"
53+
)
54+
self._processes[pid] = ProcessInfo(pid, "Unknown")
55+
self._processes[pid].add_thread(t_id, name)
56+
57+
def _populate_start_time(self, event):
58+
# TODO, not sure if we can implement this right now
59+
return
60+
61+
def read_trace_file(self):
62+
try:
63+
with open(self._trace_json_file) as json_data:
64+
trace_json_data = json.load(json_data)
65+
except Exception as e:
66+
self.logger.error(
67+
f"Can't open TF trace file {self._trace_json_file}: Exception {str(e)} "
68+
)
69+
return
70+
if "traceEvents" not in trace_json_data:
71+
self.logger.error(
72+
f"The TF trace file {self._trace_json_file} does not contain traceEvents"
73+
)
74+
return
75+
trace_events_json = trace_json_data["traceEvents"]
76+
77+
for event in trace_events_json:
78+
self._read_event(event)
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# First Party
2+
from smdebug.core.logger import get_logger
3+
4+
5+
class ThreadInfo:
6+
def __init__(self, tid, thread_name):
7+
self.tid = tid
8+
self.thread_name = thread_name
9+
10+
11+
class ProcessInfo:
12+
def __init__(self, id, name):
13+
self.id = id
14+
self.name = name
15+
self._threads = dict()
16+
17+
def add_thread(self, threadid, thread_name):
18+
self._threads[threadid] = ThreadInfo(threadid, thread_name)
19+
20+
def get_thread_info(self, threadid):
21+
return self._threads[threadid]
22+
23+
24+
class TraceEvent:
25+
def __init__(self, ts, name, dur, pid, tid, event_args):
26+
self.start_time = ts
27+
self.event_name = name
28+
self.duration = dur
29+
self.end_time = self.start_time + self.duration
30+
self.pid = pid
31+
self.tid = tid
32+
self.event_args = event_args
33+
34+
35+
class TraceEventParser:
36+
def __init__(self):
37+
self._processes = dict()
38+
self._trace_events = list()
39+
self._start_timestamp = 0
40+
self._start_time_known = False
41+
# The timestamp in trace events are in micro seconds, we multiply by 1000 to convert to ns
42+
self._timescale_multiplier_for_ns = 1000
43+
self.logger = get_logger("smdebug-profiler")
44+
45+
def read_trace_file(self):
46+
pass
47+
48+
def _populate_process_info_for_metaevent(self, event):
49+
id = event["pid"]
50+
if event["name"] == "process_name":
51+
name = event["args"]["name"] if "name" in event["args"] else "Unknown"
52+
self._processes[id] = ProcessInfo(id, name)
53+
54+
def _populate_thread_info_for_metaevent(self, event):
55+
pass
56+
57+
def _populate_start_time(self, event):
58+
pass
59+
60+
def _read_event(self, event):
61+
if "ph" not in event:
62+
self.logger.error(f"In correctly formatted trace file. The 'ph' field is not present")
63+
return
64+
phase_type = event["ph"]
65+
if phase_type == "M":
66+
self._populate_process_info_for_metaevent(event)
67+
self._populate_thread_info_for_metaevent(event)
68+
self._populate_start_time(event)
69+
70+
if phase_type == "X":
71+
# In nano seconds
72+
start_time = (event["ts"] + self._start_timestamp) * self._timescale_multiplier_for_ns
73+
# In nano seconds
74+
dur = event["dur"] * self._timescale_multiplier_for_ns
75+
name = event["name"]
76+
id = event["pid"]
77+
tid = event["tid"] if "tid" in event else "0"
78+
event_args = event["args"] if "args" in event else None
79+
t_event = TraceEvent(start_time, name, dur, id, tid, event_args)
80+
self._trace_events.append(t_event)
81+
82+
def get_all_events(self):
83+
return self._trace_events
84+
85+
def get_events_start_time_sorted(self):
86+
return sorted(self._trace_events, key=lambda x: x.start_time)
87+
88+
def get_events_end_time_sorted(self):
89+
return sorted(self._trace_events, key=lambda x: x.end_time)
90+
91+
"""
92+
Return the events that are in progress at the specified timestamp.
93+
Performance of this function can be improved by implementing interval tree.
94+
"""
95+
96+
def get_events_at(self, timestamp_in_nanoseconds):
97+
result_events = list()
98+
for x_event in self._trace_events:
99+
if x_event.start_time <= timestamp_in_nanoseconds <= x_event.end_time:
100+
result_events.append(x_event)
101+
return result_events
102+
103+
"""
104+
Return the events that have started and completed within the given start and end time boundaries.
105+
The events that are in progress during these boundaries are not included.
106+
"""
107+
108+
def get_events_within_range(self, start_time, end_time):
109+
result_events = list()
110+
for x_event in self._trace_events:
111+
if start_time <= x_event.start_time and end_time >= x_event.end_time:
112+
result_events.append(x_event)
113+
return result_events
114+
115+
def get_process_info(self, process_id):
116+
return self._processes[process_id]
117+
118+
def get_processes(self):
119+
return self._processes
120+
121+
# TODO
122+
def get_events_for_process(self, pid, start_time, end_time):
123+
pass
124+
125+
# TODO
126+
def get_events_for_thread(self, tid, start_time, end_time):
127+
pass

0 commit comments

Comments
 (0)