Skip to content

Commit a7910aa

Browse files
authored
[offload][OMPT] Add device-specific tracing control (llvm#1564)
2 parents bbb4ebf + 2276160 commit a7910aa

File tree

15 files changed

+321
-155
lines changed

15 files changed

+321
-155
lines changed

offload/include/OmptCommonDefs.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
} while (0)
9393

9494
typedef ompt_set_result_t (*libomptarget_ompt_set_trace_ompt_t)(
95-
ompt_device_t *Device, unsigned int Enable, unsigned int EventTy);
95+
int Device, unsigned int Enable, unsigned int EventTy);
9696
typedef int (*libomptarget_ompt_start_trace_t)(int,
9797
ompt_callback_buffer_request_t,
9898
ompt_callback_buffer_complete_t);

offload/include/OmptTracing.h

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -76,56 +76,66 @@ void setOmptHostToDeviceRate(double Slope, double Offset);
7676
/// Set / store the number of granted teams
7777
void setOmptGrantedNumTeams(uint64_t NumTeams);
7878

79-
/// Activate / deactivate tracing
80-
void setTracingState(bool Enabled);
79+
/// Check if (1) tracing is globally active (2) the given device is actively
80+
/// traced and (3) the given event type is traced on the device
81+
bool isTracingEnabled(int DeviceId, unsigned int EventTy);
8182

82-
/// Check if the given tracing type is monitored
83-
bool isTracingTypeEnabled(unsigned int EventTy);
83+
/// Check if the given device is actively traced
84+
bool isTracedDevice(int DeviceId);
8485

85-
/// Set whether the given tracing type should be monitored (or not)
86-
void setTracingTypeEnabled(unsigned int EventTy, bool Enable);
86+
/// Check if the given device is monitoring the provided tracing type
87+
bool isTracingTypeEnabled(int DeviceId, unsigned int EventTy);
88+
89+
/// Check if the given device is monitoring the provided tracing type 'group'
90+
/// Where group means we will check for both: EMI and non-EMI event types
91+
bool isTracingTypeGroupEnabled(int DeviceId, unsigned int EventTy);
92+
93+
/// Set whether the given tracing type should be monitored (or not) on the
94+
/// device
95+
void setTracingTypeEnabled(uint64_t &TracedEventTy, bool Enable,
96+
unsigned int EventTy);
8797

8898
/// Set / reset the given tracing types (EventTy = 0 corresponds to 'all')
89-
ompt_set_result_t setTraceEventTy(ompt_device_t *Device, unsigned int Enable,
99+
ompt_set_result_t setTraceEventTy(int DeviceId, unsigned int Enable,
90100
unsigned int EventTy);
91101

92102
/// Return thread id
93103
uint64_t getThreadId();
94104

95-
// Mutexes to serialize invocation of device-independent entry points
105+
/// Mutexes to serialize invocation of device registration and checks
106+
extern std::mutex DeviceAccessMutex;
107+
108+
/// Mutexes to serialize invocation of device-independent entry points
96109
extern std::mutex TraceAccessMutex;
97110
extern std::mutex TraceControlMutex;
98111

99-
// Ensure serialization of calls to std::hash
112+
/// Ensure serialization of calls to std::hash
100113
extern std::mutex TraceHashThreadMutex;
101114

102-
// Protect map from device-id to the corresponding buffer-request and
103-
// buffer-completion callback functions.
115+
/// Protect map from device-id to the corresponding buffer-request and
116+
/// buffer-completion callback functions.
104117
extern std::mutex BufferManagementFnMutex;
105118

106-
// Map from device-id to the corresponding buffer-request and buffer-completion
107-
// callback functions.
119+
/// Map from device-id to the corresponding buffer-request and buffer-completion
120+
/// callback functions.
108121
extern std::unordered_map<int, std::pair<ompt_callback_buffer_request_t,
109122
ompt_callback_buffer_complete_t>>
110123
BufferManagementFns;
111124

112-
// Thread local variables used by the plugin to communicate OMPT information
113-
// that are then used to populate trace records. This method assumes a
114-
// synchronous implementation, otherwise it won't work.
125+
/// Thread local variables used by the plugin to communicate OMPT information
126+
/// that are then used to populate trace records. This method assumes a
127+
/// synchronous implementation, otherwise it won't work.
115128
extern thread_local uint32_t TraceRecordNumGrantedTeams;
116129
extern thread_local uint64_t TraceRecordStartTime;
117130
extern thread_local uint64_t TraceRecordStopTime;
118131

119-
// Thread local thread-id.
132+
/// Thread local thread-id.
120133
extern thread_local uint64_t ThreadId;
121134

122-
// Manage all tracing records in one place
135+
/// Manage all tracing records in one place.
123136
extern OmptTracingBufferMgr TraceRecordManager;
124137

125-
// Keep track of enabled tracing event types
126-
extern std::atomic<uint64_t> TracingTypesEnabled;
127-
128-
/// OMPT tracing status; (Re-)Set via 'setTracingState'
138+
/// OMPT global tracing status. Indicates if at least one device is traced.
129139
extern bool TracingActive;
130140

131141
} // namespace ompt

offload/include/OpenMP/OMPT/Interface.h

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,9 @@
1515

1616
// Only provide functionality if target OMPT support is enabled
1717
#ifdef OMPT_SUPPORT
18-
#include <functional>
19-
#include <tuple>
20-
21-
#include "Shared/Debug.h"
22-
2318
#include "Callback.h"
2419
#include "Shared/APITypes.h"
20+
#include "Shared/Debug.h"
2521
#include "omp-tools.h"
2622

2723
#include "llvm/Support/ErrorHandling.h"
@@ -52,8 +48,12 @@ extern ompt_get_task_data_t ompt_get_task_data_fn;
5248
extern ompt_get_target_task_data_t ompt_get_target_task_data_fn;
5349
extern ompt_set_frame_enter_t ompt_set_frame_enter_fn;
5450

51+
/// OMPT global tracing status. Indicates if at least one device is traced.
5552
extern bool TracingActive;
5653

54+
/// Check if this device traces the given event type
55+
extern bool isTracingEnabled(int DeviceId, unsigned int EventTy);
56+
5757
/// Used to maintain execution state for this thread
5858
class Interface {
5959
public:
@@ -191,7 +191,6 @@ class Interface {
191191
unsigned int NumTeams = 1);
192192

193193
ompt_record_ompt_t *stopTargetSubmitTraceAsync(ompt_record_ompt_t *DataPtr,
194-
int64_t DeviceId,
195194
unsigned int NumTeams,
196195
uint64_t NanosStart,
197196
uint64_t NanosStop);
@@ -485,17 +484,17 @@ struct OmptEventInfoTy {
485484
/// extends the original with async capabilities. That is: It takes an
486485
/// additional AsyncInfo reference as argument to populate the relevant fields.
487486
/// The AsyncInfoTy propagates the info into the RTL / plugins.
487+
/// TracedDeviceId represents the trace record's device affinity. EventType is
488+
/// the callback type that needs to be enabled via ompt_set_trace_ompt.
488489
template <typename FunctionPairTy, typename AsyncInfoTy, typename... ArgsTy>
489490
class TracerInterfaceRAII {
490491
public:
491492
TracerInterfaceRAII(FunctionPairTy Callbacks, AsyncInfoTy &AsyncInfo,
493+
int TracedDeviceId, ompt_callbacks_t EventType,
492494
ArgsTy... Args)
493495
: Arguments(Args...), beginFunction(std::get<0>(Callbacks)) {
494496
__tgt_async_info *AI = AsyncInfo;
495-
if (!llvm::omp::target::ompt::TracingActive) {
496-
assert(AI->OmptEventInfo == nullptr &&
497-
"The OmptEventInfo was not nullptr");
498-
} else {
497+
if (isTracingEnabled(TracedDeviceId, EventType)) {
499498
auto Record = begin();
500499
// Gets freed in interface.cpp, functions
501500
// targetKernel and targetData once launching target operations returns.
@@ -505,6 +504,9 @@ class TracerInterfaceRAII {
505504
AI->OmptEventInfo->NumTeams = 0;
506505
AI->OmptEventInfo->RegionInterface = &RegionInterface;
507506
AI->OmptEventInfo->RIFunction = std::get<1>(Callbacks);
507+
} else {
508+
// Actively prevent further tracing of this event
509+
AI->OmptEventInfo = nullptr;
508510
}
509511
}
510512

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1913,7 +1913,7 @@ struct AMDGPUStreamTy {
19131913
"Invalid RegionInterface pointer in OMPT profiling");
19141914
assert(OmptEventInfo.TraceRecord && "Invalid TraceRecord");
19151915
std::invoke(RIFunc, OmptEventInfo.RegionInterface,
1916-
OmptEventInfo.TraceRecord, 0, OmptEventInfo.NumTeams, StartTime,
1916+
OmptEventInfo.TraceRecord, OmptEventInfo.NumTeams, StartTime,
19171917
EndTime);
19181918

19191919
return Plugin::success();

offload/plugins-nextgen/common/OMPT/OmptDeviceTracing.h

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ FOREACH_OMPT_DEVICE_TRACING_FN_IMPLEMENTAIONS(declareOmptTracingFnMutex)
4848

4949
extern std::mutex DeviceIdWritingMutex;
5050

51-
/// Activate / deactivate tracing
52-
void setTracingState(bool Enabled);
51+
/// Activate tracing on the given device
52+
void enableDeviceTracing(int DeviceId);
53+
54+
/// Deactivate tracing on the given device
55+
void disableDeviceTracing(int DeviceId);
5356

5457
/// Set 'start' and 'stop' in trace records
5558
void setOmptTimestamp(uint64_t StartTime, uint64_t EndTime);
@@ -69,6 +72,9 @@ void setDeviceId(ompt_device_t *Device, int32_t DeviceId);
6972
/// Rempve the given device pointer from the current mapping
7073
void removeDeviceId(ompt_device_t *Device);
7174

75+
/// Check whether the provided device is currently traced.
76+
bool isTracedDevice(int32_t DeviceId);
77+
7278
/// Provide name based lookup for the device tracing functions
7379
extern ompt_interface_fn_t
7480
lookupDeviceTracingFn(const char *InterfaceFunctionName);
@@ -82,10 +88,11 @@ extern double HostToDeviceOffset;
8288
/// Mapping of device pointers to their corresponding RTL device ID
8389
extern std::map<ompt_device_t *, int32_t> Devices;
8490

85-
// Keep track of enabled tracing event types
86-
extern std::atomic<uint64_t> TracingTypesEnabled;
91+
/// Mapping of RTL device IDs to their currently enabled tracing event types.
92+
/// Note: Event type '0' (bit position) indicates if this device is traced.
93+
extern std::map<int32_t, uint64_t> TracedDevices;
8794

88-
/// OMPT tracing status; (Re-)Set via 'setTracingState'
95+
/// OMPT global tracing status. Indicates if at least one device is traced.
8996
extern bool TracingActive;
9097

9198
/// Parent library pointer

offload/plugins-nextgen/common/OMPT/OmptTracing.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,43 +78,61 @@ int llvm::omp::target::ompt::getDeviceId(ompt_device_t *Device) {
7878
std::unique_lock<std::mutex> Lock(DeviceIdWritingMutex);
7979
auto DeviceIterator = Devices.find(Device);
8080
if (Device == nullptr || DeviceIterator == Devices.end()) {
81-
REPORT("Failed to get ID for device=%p\n", Device);
81+
REPORT("Failed to get ID for Device=%p\n", Device);
8282
return -1;
8383
}
8484
return DeviceIterator->second;
8585
}
8686

8787
void llvm::omp::target::ompt::setDeviceId(ompt_device_t *Device,
8888
int32_t DeviceId) {
89-
assert(Device && "Mapping device id to nullptr is not allowed");
90-
if (Device == nullptr) {
91-
REPORT("Failed to set ID for nullptr device\n");
89+
assert(Device && "Mapping device ID to nullptr is not allowed");
90+
if (Device == nullptr || DeviceId < 0) {
91+
REPORT("Failed to set ID=%d for Device=%p\n", DeviceId, Device);
9292
return;
9393
}
9494
std::unique_lock<std::mutex> Lock(DeviceIdWritingMutex);
95+
auto DeviceIterator = Devices.find(Device);
96+
if (DeviceIterator != Devices.end()) {
97+
auto CurrentDeviceId = DeviceIterator->second;
98+
if (DeviceId == CurrentDeviceId)
99+
REPORT("Tried to duplicate OMPT Device=%p (ID=%d)\n", Device, DeviceId);
100+
else
101+
REPORT("Tried to overwrite OMPT Device=%p (ID=%d with new ID=%d)\n",
102+
Device, CurrentDeviceId, DeviceId);
103+
return;
104+
}
95105
Devices.emplace(Device, DeviceId);
96106
}
97107

98108
void llvm::omp::target::ompt::removeDeviceId(ompt_device_t *Device) {
99-
if (Device == nullptr) {
100-
REPORT("Failed to remove ID for nullptr device\n");
109+
int DeviceId = getDeviceId(Device);
110+
if (DeviceId < 0) {
111+
REPORT("Failed to remove Device=%p (ID=%d)\n", Device, DeviceId);
101112
return;
102113
}
103114
std::unique_lock<std::mutex> Lock(DeviceIdWritingMutex);
104115
Devices.erase(Device);
116+
TracedDevices.erase(DeviceId);
105117
}
106118

107119
OMPT_API_ROUTINE ompt_set_result_t ompt_set_trace_ompt(ompt_device_t *Device,
108120
unsigned int Enable,
109121
unsigned int EventTy) {
110122
DP("Executing ompt_set_trace_ompt\n");
111123

112-
// TODO handle device
124+
int DeviceId = getDeviceId(Device);
125+
if (DeviceId < 0) {
126+
REPORT("Failed to set trace events for Device=%p (Unknown device)\n",
127+
Device);
128+
return ompt_set_never;
129+
}
130+
113131
std::unique_lock<std::mutex> Lock(ompt_set_trace_ompt_mutex);
114132
ensureFuncPtrLoaded<libomptarget_ompt_set_trace_ompt_t>(
115133
"libomptarget_ompt_set_trace_ompt", &ompt_set_trace_ompt_fn);
116134
assert(ompt_set_trace_ompt_fn && "libomptarget_ompt_set_trace_ompt loaded");
117-
return ompt_set_trace_ompt_fn(Device, Enable, EventTy);
135+
return ompt_set_trace_ompt_fn(DeviceId, Enable, EventTy);
118136
}
119137

120138
OMPT_API_ROUTINE int
@@ -123,12 +141,18 @@ ompt_start_trace(ompt_device_t *Device, ompt_callback_buffer_request_t Request,
123141
DP("Executing ompt_start_trace\n");
124142

125143
int DeviceId = getDeviceId(Device);
144+
if (DeviceId < 0) {
145+
REPORT("Failed to start trace for Device=%p (Unknown device)\n", Device);
146+
// Indicate failure
147+
return 0;
148+
}
149+
126150
{
127151
// Protect the function pointer
128152
std::unique_lock<std::mutex> Lock(ompt_start_trace_mutex);
129153

130154
if (Request && Complete) {
131-
llvm::omp::target::ompt::setTracingState(/*Enabled=*/true);
155+
llvm::omp::target::ompt::enableDeviceTracing(DeviceId);
132156
// Enable asynchronous memory copy profiling
133157
setOmptAsyncCopyProfile(/*Enable=*/true);
134158
// Enable queue dispatch profiling
@@ -150,7 +174,6 @@ ompt_start_trace(ompt_device_t *Device, ompt_callback_buffer_request_t Request,
150174
OMPT_API_ROUTINE int ompt_flush_trace(ompt_device_t *Device) {
151175
DP("Executing ompt_flush_trace\n");
152176

153-
// TODO handle device
154177
std::unique_lock<std::mutex> Lock(ompt_flush_trace_mutex);
155178
ensureFuncPtrLoaded<libomptarget_ompt_flush_trace_t>(
156179
"libomptarget_ompt_flush_trace", &ompt_flush_trace_fn);
@@ -161,15 +184,20 @@ OMPT_API_ROUTINE int ompt_flush_trace(ompt_device_t *Device) {
161184
OMPT_API_ROUTINE int ompt_stop_trace(ompt_device_t *Device) {
162185
DP("Executing ompt_stop_trace\n");
163186

164-
// TODO handle device
187+
int DeviceId = getDeviceId(Device);
188+
if (DeviceId < 0) {
189+
REPORT("Failed to stop trace for Device=%p (Unknown device)\n", Device);
190+
// Indicate failure
191+
return 0;
192+
}
193+
165194
{
166195
// Protect the function pointer
167196
std::unique_lock<std::mutex> Lock(ompt_stop_trace_mutex);
168-
llvm::omp::target::ompt::setTracingState(/*Enabled=*/false);
197+
llvm::omp::target::ompt::disableDeviceTracing(DeviceId);
169198
// Disable asynchronous memory copy profiling
170199
setOmptAsyncCopyProfile(/*Enable=*/false);
171200
// Disable queue dispatch profiling
172-
int DeviceId = getDeviceId(Device);
173201
if (DeviceId >= 0)
174202
setGlobalOmptKernelProfile(Device, /*Enable=*/0);
175203
else
@@ -179,7 +207,7 @@ OMPT_API_ROUTINE int ompt_stop_trace(ompt_device_t *Device) {
179207
"libomptarget_ompt_stop_trace", &ompt_stop_trace_fn);
180208
assert(ompt_stop_trace_fn && "libomptarget_ompt_stop_trace loaded");
181209
}
182-
return ompt_stop_trace_fn(getDeviceId(Device));
210+
return ompt_stop_trace_fn(DeviceId);
183211
}
184212

185213
OMPT_API_ROUTINE ompt_record_ompt_t *

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -675,10 +675,15 @@ Error GenericKernelTy::launch(GenericDeviceTy &GenericDevice, void **ArgPtrs,
675675
printLaunchInfo(GenericDevice, KernelArgs, NumThreads, NumBlocks))
676676
return Err;
677677

678-
OMPT_IF_TRACING_ENABLED(setOmptGrantedNumTeams(NumBlocks);
679-
// Set number of granted teams for OMPT
680-
__tgt_async_info *AI = AsyncInfoWrapper;
681-
AI->OmptEventInfo->NumTeams = NumBlocks;);
678+
OMPT_IF_TRACING_ENABLED(if (llvm::omp::target::ompt::isTracedDevice(
679+
getDeviceId(&GenericDevice))) {
680+
__tgt_async_info *AI = AsyncInfoWrapper;
681+
if (AI->OmptEventInfo != nullptr) {
682+
// Set number of granted teams for OMPT
683+
setOmptGrantedNumTeams(NumBlocks);
684+
AI->OmptEventInfo->NumTeams = NumBlocks;
685+
}
686+
});
682687

683688
return launchImpl(GenericDevice, NumThreads, NumBlocks, KernelArgs,
684689
LaunchParams, AsyncInfoWrapper);
@@ -871,7 +876,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
871876

872877
#ifdef OMPT_SUPPORT
873878
auto DevicePtr = reinterpret_cast<ompt_device_t *>(this);
874-
ompt::setDeviceId(DevicePtr, DeviceId);
879+
ompt::setDeviceId(DevicePtr, Plugin.getUserId(DeviceId));
875880
if (ompt::CallbacksInitialized) {
876881
bool ExpectedStatus = false;
877882
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true))

0 commit comments

Comments
 (0)