Skip to content

Commit ebdcb2a

Browse files
author
git apple-llvm automerger
committed
Merge commit 'da3dbbb616a2' from apple/main into swift/next
2 parents b3686ea + da3dbbb commit ebdcb2a

File tree

5 files changed

+83
-61
lines changed

5 files changed

+83
-61
lines changed

openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.cpp

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,10 @@
88

99
using core::atl_is_atmi_initialized;
1010

11-
atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
12-
const char *symbol,
13-
void **var_addr,
14-
unsigned int *var_size) {
11+
atmi_status_t atmi_interop_hsa_get_symbol_info(
12+
const std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
13+
atmi_mem_place_t place, const char *symbol, void **var_addr,
14+
unsigned int *var_size) {
1515
/*
1616
// Typical usage:
1717
void *var_addr;
@@ -32,9 +32,9 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
3232

3333
// get the symbol info
3434
std::string symbolStr = std::string(symbol);
35-
if (SymbolInfoTable[place.dev_id].find(symbolStr) !=
36-
SymbolInfoTable[place.dev_id].end()) {
37-
atl_symbol_info_t info = SymbolInfoTable[place.dev_id][symbolStr];
35+
auto It = SymbolInfoTable.find(symbolStr);
36+
if (It != SymbolInfoTable.end()) {
37+
atl_symbol_info_t info = It->second;
3838
*var_addr = reinterpret_cast<void *>(info.addr);
3939
*var_size = info.size;
4040
return ATMI_STATUS_SUCCESS;
@@ -46,6 +46,7 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
4646
}
4747

4848
atmi_status_t atmi_interop_hsa_get_kernel_info(
49+
const std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
4950
atmi_mem_place_t place, const char *kernel_name,
5051
hsa_executable_symbol_info_t kernel_info, uint32_t *value) {
5152
/*
@@ -68,9 +69,9 @@ atmi_status_t atmi_interop_hsa_get_kernel_info(
6869
atmi_status_t status = ATMI_STATUS_SUCCESS;
6970
// get the kernel info
7071
std::string kernelStr = std::string(kernel_name);
71-
if (KernelInfoTable[place.dev_id].find(kernelStr) !=
72-
KernelInfoTable[place.dev_id].end()) {
73-
atl_kernel_info_t info = KernelInfoTable[place.dev_id][kernelStr];
72+
auto It = KernelInfoTable.find(kernelStr);
73+
if (It != KernelInfoTable.end()) {
74+
atl_kernel_info_t info = It->second;
7475
switch (kernel_info) {
7576
case HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_GROUP_SEGMENT_SIZE:
7677
*value = info.group_segment_size;

openmp/libomptarget/plugins/amdgpu/impl/atmi_interop_hsa.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
#include "atmi_runtime.h"
1010
#include "hsa.h"
1111
#include "hsa_ext_amd.h"
12+
#include "internal.h"
13+
14+
#include <map>
15+
#include <string>
1216

1317
#ifdef __cplusplus
1418
extern "C" {
@@ -44,11 +48,10 @@ extern "C" {
4448
*
4549
* @retval ::ATMI_STATUS_UNKNOWN The function encountered errors.
4650
*/
47-
atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
48-
const char *symbol,
49-
void **var_addr,
50-
unsigned int *var_size);
51-
51+
atmi_status_t atmi_interop_hsa_get_symbol_info(
52+
const std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
53+
atmi_mem_place_t place, const char *symbol, void **var_addr,
54+
unsigned int *var_size);
5255
/**
5356
* @brief Get the HSA-specific kernel info from a kernel name
5457
*
@@ -75,8 +78,10 @@ atmi_status_t atmi_interop_hsa_get_symbol_info(atmi_mem_place_t place,
7578
* @retval ::ATMI_STATUS_UNKNOWN The function encountered errors.
7679
*/
7780
atmi_status_t atmi_interop_hsa_get_kernel_info(
81+
const std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
7882
atmi_mem_place_t place, const char *kernel_name,
7983
hsa_executable_symbol_info_t info, uint32_t *value);
84+
8085
/** @} */
8186

8287
#ifdef __cplusplus

openmp/libomptarget/plugins/amdgpu/impl/internal.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,6 @@ typedef struct atl_symbol_info_s {
106106
uint32_t size;
107107
} atl_symbol_info_t;
108108

109-
extern std::vector<std::map<std::string, atl_kernel_info_t>> KernelInfoTable;
110-
extern std::vector<std::map<std::string, atl_symbol_info_t>> SymbolInfoTable;
111-
112109
// ---------------------- Kernel End -------------
113110

114111
namespace core {

openmp/libomptarget/plugins/amdgpu/impl/system.cpp

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,6 @@ ATLMachine g_atl_machine;
146146

147147
std::vector<hsa_amd_memory_pool_t> atl_gpu_kernarg_pools;
148148

149-
std::vector<std::map<std::string, atl_kernel_info_t>> KernelInfoTable;
150-
std::vector<std::map<std::string, atl_symbol_info_t>> SymbolInfoTable;
151-
152149
bool g_atmi_initialized = false;
153150

154151
/*
@@ -208,15 +205,6 @@ atmi_status_t Runtime::Initialize() {
208205

209206
atmi_status_t Runtime::Finalize() {
210207
atmi_status_t rc = ATMI_STATUS_SUCCESS;
211-
for (uint32_t i = 0; i < SymbolInfoTable.size(); i++) {
212-
SymbolInfoTable[i].clear();
213-
}
214-
SymbolInfoTable.clear();
215-
for (uint32_t i = 0; i < KernelInfoTable.size(); i++) {
216-
KernelInfoTable[i].clear();
217-
}
218-
KernelInfoTable.clear();
219-
220208
atl_reset_atmi_initialized();
221209
hsa_status_t err = hsa_shut_down();
222210
if (err != HSA_STATUS_SUCCESS) {
@@ -556,13 +544,6 @@ hsa_status_t init_hsa() {
556544
return err;
557545
}
558546

559-
int gpu_count = g_atl_machine.processorCount<ATLGPUProcessor>();
560-
KernelInfoTable.resize(gpu_count);
561-
SymbolInfoTable.resize(gpu_count);
562-
for (uint32_t i = 0; i < SymbolInfoTable.size(); i++)
563-
SymbolInfoTable[i].clear();
564-
for (uint32_t i = 0; i < KernelInfoTable.size(); i++)
565-
KernelInfoTable[i].clear();
566547
atlc.g_hsa_initialized = true;
567548
DEBUG_PRINT("done\n");
568549
}
@@ -835,8 +816,9 @@ int populate_kernelArgMD(msgpack::byte_range args_element,
835816
}
836817
} // namespace
837818

838-
static hsa_status_t get_code_object_custom_metadata(void *binary,
839-
size_t binSize, int gpu) {
819+
static hsa_status_t get_code_object_custom_metadata(
820+
void *binary, size_t binSize, int gpu,
821+
std::map<std::string, atl_kernel_info_t> &KernelInfoTable) {
840822
// parse code object with different keys from v2
841823
// also, the kernel name is not the same as the symbol name -- so a
842824
// symbol->name map is needed
@@ -1003,14 +985,16 @@ static hsa_status_t get_code_object_custom_metadata(void *binary,
1003985
kernel_segment_size, info.kernel_segment_size);
1004986

1005987
// kernel received, now add it to the kernel info table
1006-
KernelInfoTable[gpu][kernelName] = info;
988+
KernelInfoTable[kernelName] = info;
1007989
}
1008990

1009991
return HSA_STATUS_SUCCESS;
1010992
}
1011993

1012-
static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
1013-
int gpu) {
994+
static hsa_status_t
995+
populate_InfoTables(hsa_executable_symbol_t symbol, int gpu,
996+
std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
997+
std::map<std::string, atl_symbol_info_t> &SymbolInfoTable) {
1014998
hsa_symbol_kind_t type;
1015999

10161000
uint32_t name_length;
@@ -1047,11 +1031,16 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
10471031
// by now, the kernel info table should already have an entry
10481032
// because the non-ROCr custom code object parsing is called before
10491033
// iterating over the code object symbols using ROCr
1050-
if (KernelInfoTable[gpu].find(kernelName) == KernelInfoTable[gpu].end()) {
1051-
return HSA_STATUS_ERROR;
1034+
if (KernelInfoTable.find(kernelName) == KernelInfoTable.end()) {
1035+
if (HSA_STATUS_ERROR_INVALID_CODE_OBJECT != HSA_STATUS_SUCCESS) {
1036+
printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
1037+
"Finding the entry kernel info table",
1038+
get_error_string(HSA_STATUS_ERROR_INVALID_CODE_OBJECT));
1039+
exit(1);
1040+
}
10521041
}
10531042
// found, so assign and update
1054-
info = KernelInfoTable[gpu][kernelName];
1043+
info = KernelInfoTable[kernelName];
10551044

10561045
/* Extract dispatch information from the symbol */
10571046
err = hsa_executable_symbol_get_info(
@@ -1089,7 +1078,7 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
10891078
info.private_segment_size, info.kernel_segment_size);
10901079

10911080
// assign it back to the kernel info table
1092-
KernelInfoTable[gpu][kernelName] = info;
1081+
KernelInfoTable[kernelName] = info;
10931082
free(name);
10941083
} else if (type == HSA_SYMBOL_KIND_VARIABLE) {
10951084
err = hsa_executable_symbol_get_info(
@@ -1135,15 +1124,17 @@ static hsa_status_t populate_InfoTables(hsa_executable_symbol_t symbol,
11351124
if (err != HSA_STATUS_SUCCESS) {
11361125
return err;
11371126
}
1138-
SymbolInfoTable[gpu][std::string(name)] = info;
1127+
SymbolInfoTable[std::string(name)] = info;
11391128
free(name);
11401129
} else {
11411130
DEBUG_PRINT("Symbol is an indirect function\n");
11421131
}
11431132
return HSA_STATUS_SUCCESS;
11441133
}
11451134

1146-
atmi_status_t Runtime::RegisterModuleFromMemory(
1135+
atmi_status_t RegisterModuleFromMemory(
1136+
std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
1137+
std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
11471138
void *module_bytes, size_t module_size, atmi_place_t place,
11481139
atmi_status_t (*on_deserialized_data)(void *data, size_t size,
11491140
void *cb_state),
@@ -1183,7 +1174,8 @@ atmi_status_t Runtime::RegisterModuleFromMemory(
11831174
// Some metadata info is not available through ROCr API, so use custom
11841175
// code object metadata parsing to collect such metadata info
11851176

1186-
err = get_code_object_custom_metadata(module_bytes, module_size, gpu);
1177+
err = get_code_object_custom_metadata(module_bytes, module_size, gpu,
1178+
KernelInfoTable);
11871179
if (err != HSA_STATUS_SUCCESS) {
11881180
DEBUG_PRINT("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
11891181
"Getting custom code object metadata",
@@ -1240,9 +1232,9 @@ atmi_status_t Runtime::RegisterModuleFromMemory(
12401232
err = hsa::executable_iterate_symbols(
12411233
executable,
12421234
[&](hsa_executable_t, hsa_executable_symbol_t symbol) -> hsa_status_t {
1243-
return populate_InfoTables(symbol, gpu);
1235+
return populate_InfoTables(symbol, gpu, KernelInfoTable,
1236+
SymbolInfoTable);
12441237
});
1245-
12461238
if (err != HSA_STATUS_SUCCESS) {
12471239
printf("[%s:%d] %s failed: %s\n", __FILE__, __LINE__,
12481240
"Iterating over symbols for execuatable", get_error_string(err));

openmp/libomptarget/plugins/amdgpu/src/rtl.cpp

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ int print_kernel_trace;
8686

8787
#include "elf_common.h"
8888

89+
namespace core {
90+
atmi_status_t RegisterModuleFromMemory(
91+
std::map<std::string, atl_kernel_info_t> &KernelInfo,
92+
std::map<std::string, atl_symbol_info_t> &SymbolInfoTable, void *, size_t,
93+
atmi_place_t,
94+
atmi_status_t (*on_deserialized_data)(void *data, size_t size,
95+
void *cb_state),
96+
void *cb_state, std::vector<hsa_executable_t> &HSAExecutables);
97+
}
98+
8999
/// Keep entries table per device
90100
struct FuncOrGblEntryTy {
91101
__tgt_target_table Table;
@@ -339,6 +349,9 @@ class RTLDeviceInfoTy {
339349

340350
std::vector<hsa_executable_t> HSAExecutables;
341351

352+
std::vector<std::map<std::string, atl_kernel_info_t>> KernelInfoTable;
353+
std::vector<std::map<std::string, atl_symbol_info_t>> SymbolInfoTable;
354+
342355
struct atmiFreePtrDeletor {
343356
void operator()(void *p) {
344357
atmi_free(p); // ignore failure to free
@@ -482,6 +495,8 @@ class RTLDeviceInfoTy {
482495
NumTeams.resize(NumberOfDevices);
483496
NumThreads.resize(NumberOfDevices);
484497
deviceStateStore.resize(NumberOfDevices);
498+
KernelInfoTable.resize(NumberOfDevices);
499+
SymbolInfoTable.resize(NumberOfDevices);
485500

486501
for (int i = 0; i < NumberOfDevices; i++) {
487502
HSAQueues[i] = nullptr;
@@ -993,15 +1008,17 @@ atmi_status_t interop_get_symbol_info(char *base, size_t img_size,
9931008

9941009
template <typename C>
9951010
atmi_status_t module_register_from_memory_to_place(
1011+
std::map<std::string, atl_kernel_info_t> &KernelInfoTable,
1012+
std::map<std::string, atl_symbol_info_t> &SymbolInfoTable,
9961013
void *module_bytes, size_t module_size, atmi_place_t place, C cb,
9971014
std::vector<hsa_executable_t> &HSAExecutables) {
9981015
auto L = [](void *data, size_t size, void *cb_state) -> atmi_status_t {
9991016
C *unwrapped = static_cast<C *>(cb_state);
10001017
return (*unwrapped)(data, size);
10011018
};
1002-
return core::Runtime::RegisterModuleFromMemory(
1003-
module_bytes, module_size, place, L, static_cast<void *>(&cb),
1004-
HSAExecutables);
1019+
return core::RegisterModuleFromMemory(
1020+
KernelInfoTable, SymbolInfoTable, module_bytes, module_size, place, L,
1021+
static_cast<void *>(&cb), HSAExecutables);
10051022
}
10061023
} // namespace
10071024

@@ -1116,11 +1133,12 @@ struct device_environment {
11161133
DP("Setting global device environment after load (%u bytes)\n",
11171134
si.size);
11181135
int device_id = host_device_env.device_num;
1119-
1136+
auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id];
11201137
void *state_ptr;
11211138
uint32_t state_ptr_size;
11221139
atmi_status_t err = atmi_interop_hsa_get_symbol_info(
1123-
get_gpu_mem_place(device_id), sym(), &state_ptr, &state_ptr_size);
1140+
SymbolInfo, get_gpu_mem_place(device_id), sym(), &state_ptr,
1141+
&state_ptr_size);
11241142
if (err != ATMI_STATUS_SUCCESS) {
11251143
DP("failed to find %s in loaded image\n", sym());
11261144
return err;
@@ -1205,8 +1223,11 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
12051223
auto env = device_environment(device_id, DeviceInfo.NumberOfDevices, image,
12061224
img_size);
12071225

1226+
auto &KernelInfo = DeviceInfo.KernelInfoTable[device_id];
1227+
auto &SymbolInfo = DeviceInfo.SymbolInfoTable[device_id];
12081228
atmi_status_t err = module_register_from_memory_to_place(
1209-
(void *)image->ImageStart, img_size, get_gpu_place(device_id),
1229+
KernelInfo, SymbolInfo, (void *)image->ImageStart, img_size,
1230+
get_gpu_place(device_id),
12101231
[&](void *data, size_t size) {
12111232
if (image_contains_symbol(data, size, "needs_hostcall_buffer")) {
12121233
__atomic_store_n(&DeviceInfo.hostcall_required, true,
@@ -1241,9 +1262,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
12411262

12421263
void *state_ptr;
12431264
uint32_t state_ptr_size;
1265+
auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id];
12441266
atmi_status_t err = atmi_interop_hsa_get_symbol_info(
1245-
get_gpu_mem_place(device_id), "omptarget_nvptx_device_State",
1246-
&state_ptr, &state_ptr_size);
1267+
SymbolInfoMap, get_gpu_mem_place(device_id),
1268+
"omptarget_nvptx_device_State", &state_ptr, &state_ptr_size);
12471269

12481270
if (err != ATMI_STATUS_SUCCESS) {
12491271
DP("No device_state symbol found, skipping initialization\n");
@@ -1325,8 +1347,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
13251347
void *varptr;
13261348
uint32_t varsize;
13271349

1350+
auto &SymbolInfoMap = DeviceInfo.SymbolInfoTable[device_id];
13281351
atmi_status_t err = atmi_interop_hsa_get_symbol_info(
1329-
get_gpu_mem_place(device_id), e->name, &varptr, &varsize);
1352+
SymbolInfoMap, get_gpu_mem_place(device_id), e->name, &varptr,
1353+
&varsize);
13301354

13311355
if (err != ATMI_STATUS_SUCCESS) {
13321356
// Inform the user what symbol prevented offloading
@@ -1367,8 +1391,10 @@ __tgt_target_table *__tgt_rtl_load_binary_locked(int32_t device_id,
13671391

13681392
atmi_mem_place_t place = get_gpu_mem_place(device_id);
13691393
uint32_t kernarg_segment_size;
1394+
auto &KernelInfoMap = DeviceInfo.KernelInfoTable[device_id];
13701395
atmi_status_t err = atmi_interop_hsa_get_kernel_info(
1371-
place, e->name, HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE,
1396+
KernelInfoMap, place, e->name,
1397+
HSA_EXECUTABLE_SYMBOL_INFO_KERNEL_KERNARG_SEGMENT_SIZE,
13721398
&kernarg_segment_size);
13731399

13741400
// each arg is a void * in this openmp implementation
@@ -1794,6 +1820,7 @@ int32_t __tgt_rtl_run_target_team_region_locked(
17941820
KernelTy *KernelInfo = (KernelTy *)tgt_entry_ptr;
17951821

17961822
std::string kernel_name = std::string(KernelInfo->Name);
1823+
auto &KernelInfoTable = DeviceInfo.KernelInfoTable;
17971824
if (KernelInfoTable[device_id].find(kernel_name) ==
17981825
KernelInfoTable[device_id].end()) {
17991826
DP("Kernel %s not found\n", kernel_name.c_str());

0 commit comments

Comments
 (0)