8
8
9
9
#include " xpti/xpti_trace_framework.h"
10
10
11
+ #include < dlfcn.h>
12
+ #include < iostream>
13
+ #include < mutex>
14
+ #include < string>
11
15
#include < sycl/detail/spinlock.hpp>
12
16
13
17
sycl::detail::SpinLock GlobalLock;
14
18
15
19
bool HasZEPrinter = false ;
16
- bool HasCUPrinter = false ;
17
- bool HasPIPrinter = false ;
18
- bool HasSYCLPrinter = false ;
19
20
20
- void zePrintersInit ();
21
- void zePrintersFinish ();
21
+ std::string getCurrentDSODir () {
22
+ auto CurrentFunc = reinterpret_cast <const void *>(&getCurrentDSODir);
23
+ Dl_info Info;
24
+ int RetCode = dladdr (CurrentFunc, &Info);
25
+ if (0 == RetCode) {
26
+ // This actually indicates an error
27
+ return " " ;
28
+ }
29
+
30
+ auto Path = std::string (Info.dli_fname );
31
+ auto LastSlashPos = Path.find_last_of (' /' );
32
+
33
+ return Path.substr (0 , LastSlashPos);
34
+ }
35
+
36
+ class CollectorLibraryWrapper {
37
+ typedef void (*InitFuncType)();
38
+ typedef void (*FinishFuncType)();
39
+ typedef void (*CallbackFuncType)(uint16_t , xpti::trace_event_data_t *,
40
+ xpti::trace_event_data_t *, uint64_t ,
41
+ const void *);
42
+ typedef void (*SetIndentLvlFuncType)(int );
43
+
44
+ public:
45
+ CollectorLibraryWrapper (const std::string &LibraryName)
46
+ : MLibraryName(LibraryName){};
47
+ ~CollectorLibraryWrapper () { clear (); };
48
+
49
+ const std::string InitFuncName = " init" ;
50
+ const std::string FinishFuncName = " finish" ;
51
+ const std::string CallbackFuncName = " callback" ;
52
+ const std::string IndentFuncName = " setIndentationLevel" ;
53
+
54
+ bool initPrinters () {
55
+ std::string Path = getCurrentDSODir ();
56
+ if (Path.empty ())
57
+ return false ;
58
+ Path += " /" + MLibraryName;
59
+ MHandle = dlopen (Path.c_str (), RTLD_LAZY);
60
+ if (!MHandle) {
61
+ std::cerr << " Cannot load library: " << dlerror () << ' \n ' ;
62
+ return false ;
63
+ }
64
+ auto ExportSymbol = [&](void *&FuncPtr, const std::string &FuncName) {
65
+ FuncPtr = dlsym (MHandle, FuncName.c_str ());
66
+ if (!FuncPtr) {
67
+ std::cerr << " Cannot export symbol: " << dlerror () << ' \n ' ;
68
+ return false ;
69
+ }
70
+ return true ;
71
+ };
72
+ if (!ExportSymbol (MInitPtr, InitFuncName) ||
73
+ !ExportSymbol (MFinishPtr, FinishFuncName) ||
74
+ !ExportSymbol (MSetIndentationLevelPtr, IndentFuncName) ||
75
+ !ExportSymbol (MCallbackPtr, CallbackFuncName)) {
76
+ clear ();
77
+ return false ;
78
+ }
79
+
80
+ if (MIndentationLevel)
81
+ ((SetIndentLvlFuncType)MSetIndentationLevelPtr)(MIndentationLevel);
82
+
83
+ ((InitFuncType)MInitPtr)();
84
+
85
+ return true ;
86
+ }
87
+
88
+ void finishPrinters () {
89
+ if (MHandle)
90
+ ((FinishFuncType)MFinishPtr)();
91
+ }
92
+
93
+ void setIndentationLevel (int Level) {
94
+ MIndentationLevel = Level;
95
+ if (MHandle)
96
+ ((SetIndentLvlFuncType)MSetIndentationLevelPtr)(MIndentationLevel);
97
+ }
98
+
99
+ void callback (uint16_t TraceType, xpti::trace_event_data_t *Parent,
100
+ xpti::trace_event_data_t *Event, uint64_t Instance,
101
+ const void *UserData) {
102
+ // Not expected to be called when MHandle == NULL since we should not be
103
+ // subscribed if init failed. Although still do the check for sure.
104
+ if (MHandle)
105
+ ((CallbackFuncType)MCallbackPtr)(TraceType, Parent, Event, Instance,
106
+ UserData);
107
+ }
108
+
109
+ void clear () {
110
+ MInitPtr = nullptr ;
111
+ MFinishPtr = nullptr ;
112
+ MCallbackPtr = nullptr ;
113
+ MSetIndentationLevelPtr = nullptr ;
114
+
115
+ if (MHandle)
116
+ dlclose (MHandle);
117
+ MHandle = nullptr ;
118
+ }
119
+
120
+ private:
121
+ std::string MLibraryName;
122
+ int MIndentationLevel = 0 ;
123
+
124
+ void *MHandle = nullptr ;
125
+
126
+ void *MInitPtr = nullptr ;
127
+ void *MFinishPtr = nullptr ;
128
+ void *MCallbackPtr = nullptr ;
129
+ void *MSetIndentationLevelPtr = nullptr ;
130
+ } zeCollectorLibrary(" libze_trace_collector.so" ),
131
+ cudaCollectorLibrary (" libcuda_trace_collector.so" );
132
+
133
+ // These routing functions are needed to be able to use GlobalLock for
134
+ // dynamically loaded collectors.
135
+ XPTI_CALLBACK_API void zeCallback (uint16_t TraceType,
136
+ xpti::trace_event_data_t *Parent,
137
+ xpti::trace_event_data_t *Event,
138
+ uint64_t Instance, const void *UserData) {
139
+ std::lock_guard<sycl::detail::SpinLock> _{GlobalLock};
140
+ return zeCollectorLibrary.callback (TraceType, Parent, Event, Instance,
141
+ UserData);
142
+ }
22
143
#ifdef USE_PI_CUDA
23
- void cuPrintersInit ();
24
- void cuPrintersFinish ();
144
+ XPTI_CALLBACK_API void cudaCallback (uint16_t TraceType,
145
+ xpti::trace_event_data_t *Parent,
146
+ xpti::trace_event_data_t *Event,
147
+ uint64_t Instance, const void *UserData) {
148
+ std::lock_guard<sycl::detail::SpinLock> _{GlobalLock};
149
+ return cudaCollectorLibrary.callback (TraceType, Parent, Event, Instance,
150
+ UserData);
151
+ }
25
152
#endif
153
+
26
154
void piPrintersInit ();
27
155
void piPrintersFinish ();
28
156
void syclPrintersInit ();
@@ -32,20 +160,11 @@ XPTI_CALLBACK_API void piCallback(uint16_t TraceType,
32
160
xpti::trace_event_data_t *Parent,
33
161
xpti::trace_event_data_t *Event,
34
162
uint64_t Instance, const void *UserData);
35
- XPTI_CALLBACK_API void zeCallback (uint16_t TraceType,
36
- xpti::trace_event_data_t *Parent,
37
- xpti::trace_event_data_t *Event,
38
- uint64_t Instance, const void *UserData);
39
- #ifdef USE_PI_CUDA
40
- XPTI_CALLBACK_API void cuCallback (uint16_t TraceType,
41
- xpti::trace_event_data_t *Parent,
42
- xpti::trace_event_data_t *Event,
43
- uint64_t Instance, const void *UserData);
44
- #endif
45
163
XPTI_CALLBACK_API void syclCallback (uint16_t TraceType,
46
164
xpti::trace_event_data_t *Parent,
47
165
xpti::trace_event_data_t *Event,
48
166
uint64_t Instance, const void *UserData);
167
+
49
168
XPTI_CALLBACK_API void xptiTraceInit (unsigned int /* major_version*/ ,
50
169
unsigned int /* minor_version*/ ,
51
170
const char * /* version_str*/ ,
@@ -58,30 +177,34 @@ XPTI_CALLBACK_API void xptiTraceInit(unsigned int /*major_version*/,
58
177
piCallback);
59
178
xptiRegisterCallback (StreamID, xpti::trace_function_with_args_end,
60
179
piCallback);
180
+ zeCollectorLibrary.setIndentationLevel (1 );
181
+ cudaCollectorLibrary.setIndentationLevel (1 );
61
182
#ifdef SYCL_HAS_LEVEL_ZERO
62
183
} else if (std::string_view (StreamName) ==
63
184
" sycl.experimental.level_zero.debug" &&
64
185
std::getenv (" SYCL_TRACE_ZE_ENABLE" )) {
65
- zePrintersInit ();
66
- uint16_t StreamID = xptiRegisterStream (StreamName);
67
- xptiRegisterCallback (StreamID, xpti::trace_function_with_args_begin,
68
- zeCallback);
69
- xptiRegisterCallback (StreamID, xpti::trace_function_with_args_end,
70
- zeCallback);
186
+ if (zeCollectorLibrary.initPrinters ()) {
187
+ HasZEPrinter = true ;
188
+ uint16_t StreamID = xptiRegisterStream (StreamName);
189
+ xptiRegisterCallback (StreamID, xpti::trace_function_with_args_begin,
190
+ zeCallback);
191
+ xptiRegisterCallback (StreamID, xpti::trace_function_with_args_end,
192
+ zeCallback);
193
+ }
71
194
#endif
72
195
#ifdef USE_PI_CUDA
73
196
} else if (std::string_view (StreamName) == " sycl.experimental.cuda.debug" &&
74
197
std::getenv (" SYCL_TRACE_CU_ENABLE" )) {
75
- cuPrintersInit ();
76
- uint16_t StreamID = xptiRegisterStream (StreamName);
77
- xptiRegisterCallback (StreamID, xpti::trace_function_with_args_begin,
78
- cuCallback);
79
- xptiRegisterCallback (StreamID, xpti::trace_function_with_args_end,
80
- cuCallback);
198
+ if (cudaCollectorLibrary.initPrinters ()) {
199
+ uint16_t StreamID = xptiRegisterStream (StreamName);
200
+ xptiRegisterCallback (StreamID, xpti::trace_function_with_args_begin,
201
+ cudaCallback);
202
+ xptiRegisterCallback (StreamID, xpti::trace_function_with_args_end,
203
+ cudaCallback);
204
+ }
81
205
#endif
82
- }
83
- if (std::string_view (StreamName) == " sycl" &&
84
- std::getenv (" SYCL_TRACE_API_ENABLE" )) {
206
+ } else if (std::string_view (StreamName) == " sycl" &&
207
+ std::getenv (" SYCL_TRACE_API_ENABLE" )) {
85
208
syclPrintersInit ();
86
209
uint16_t StreamID = xptiRegisterStream (StreamName);
87
210
xptiRegisterCallback (StreamID, xpti::trace_diagnostics, syclCallback);
@@ -95,13 +218,17 @@ XPTI_CALLBACK_API void xptiTraceFinish(const char *StreamName) {
95
218
#ifdef SYCL_HAS_LEVEL_ZERO
96
219
else if (std::string_view (StreamName) ==
97
220
" sycl.experimental.level_zero.debug" &&
98
- std::getenv (" SYCL_TRACE_ZE_ENABLE" ))
99
- zePrintersFinish ();
221
+ std::getenv (" SYCL_TRACE_ZE_ENABLE" )) {
222
+ zeCollectorLibrary.finishPrinters ();
223
+ zeCollectorLibrary.clear ();
224
+ }
100
225
#endif
101
226
#ifdef USE_PI_CUDA
102
227
else if (std::string_view (StreamName) == " sycl.experimental.cuda.debug" &&
103
- std::getenv (" SYCL_TRACE_CU_ENABLE" ))
104
- cuPrintersFinish ();
228
+ std::getenv (" SYCL_TRACE_CU_ENABLE" )) {
229
+ cudaCollectorLibrary.finishPrinters ();
230
+ cudaCollectorLibrary.clear ();
231
+ }
105
232
#endif
106
233
if (std::string_view (StreamName) == " sycl" &&
107
234
std::getenv (" SYCL_TRACE_API_ENABLE" ))
0 commit comments