Skip to content

Commit 8cdd099

Browse files
authored
Merge pull request intel#954 from jchlanda/jakub/rqwgs_hip
[HIP] Handle required wg size attribute in HIP
2 parents fc9bb61 + c893a3c commit 8cdd099

File tree

8 files changed

+62
-41
lines changed

8 files changed

+62
-41
lines changed

source/adapters/cuda/program.cpp

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "program.hpp"
12+
#include "ur_util.hpp"
1213

1314
bool getMaxRegistersJitOptionValue(const std::string &BuildOptions,
1415
unsigned int &Value) {
@@ -52,15 +53,6 @@ ur_program_handle_t_::ur_program_handle_t_(ur_context_handle_t Context)
5253

5354
ur_program_handle_t_::~ur_program_handle_t_() { urContextRelease(Context); }
5455

55-
std::pair<std::string, std::string>
56-
splitMetadataName(const std::string &metadataName) {
57-
size_t splitPos = metadataName.rfind('@');
58-
if (splitPos == std::string::npos)
59-
return std::make_pair(metadataName, std::string{});
60-
return std::make_pair(metadataName.substr(0, splitPos),
61-
metadataName.substr(splitPos, metadataName.length()));
62-
}
63-
6456
ur_result_t
6557
ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
6658
size_t Length) {

source/adapters/hip/enqueue.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1850,10 +1850,14 @@ setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
18501850
static_cast<size_t>(Device->getMaxBlockDimY()),
18511851
static_cast<size_t>(Device->getMaxBlockDimZ())};
18521852

1853+
auto &ReqdThreadsPerBlock = Kernel->ReqdThreadsPerBlock;
18531854
MaxWorkGroupSize = Device->getMaxWorkGroupSize();
18541855

18551856
if (LocalWorkSize != nullptr) {
18561857
auto isValid = [&](int dim) {
1858+
UR_ASSERT(ReqdThreadsPerBlock[dim] == 0 ||
1859+
LocalWorkSize[dim] == ReqdThreadsPerBlock[dim],
1860+
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
18571861
UR_ASSERT(LocalWorkSize[dim] <= MaxThreadsPerBlock[dim],
18581862
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
18591863
// Checks that local work sizes are a divisor of the global work sizes

source/adapters/hip/kernel.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,17 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
9191
return ReturnValue(size_t(MaxThreads));
9292
}
9393
case UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE: {
94-
size_t group_size[3] = {0, 0, 0};
95-
// Returns the work-group size specified in the kernel source or IL.
96-
// If the work-group size is not specified in the kernel source or IL,
97-
// (0, 0, 0) is returned.
98-
// https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/man/xhtml/clGetKernelWorkGroupInfo.html
99-
100-
// TODO: can we extract the work group size from the PTX?
101-
return ReturnValue(group_size, 3);
94+
size_t GroupSize[3] = {0, 0, 0};
95+
const auto &ReqdWGSizeMDMap =
96+
hKernel->getProgram()->KernelReqdWorkGroupSizeMD;
97+
const auto ReqdWGSizeMD = ReqdWGSizeMDMap.find(hKernel->getName());
98+
if (ReqdWGSizeMD != ReqdWGSizeMDMap.end()) {
99+
const auto ReqdWGSize = ReqdWGSizeMD->second;
100+
GroupSize[0] = std::get<0>(ReqdWGSize);
101+
GroupSize[1] = std::get<1>(ReqdWGSize);
102+
GroupSize[2] = std::get<2>(ReqdWGSize);
103+
}
104+
return ReturnValue(GroupSize, 3);
102105
}
103106
case UR_KERNEL_GROUP_INFO_LOCAL_MEM_SIZE: {
104107
// OpenCL LOCAL == HIP SHARED

source/adapters/hip/kernel.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ struct ur_kernel_handle_t_ {
4242
ur_program_handle_t Program;
4343
std::atomic_uint32_t RefCount;
4444

45+
static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
46+
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
47+
4548
/// Structure that holds the arguments to the kernel.
4649
/// Note earch argument size is known, since it comes
4750
/// from the kernel signature.
@@ -154,6 +157,11 @@ struct ur_kernel_handle_t_ {
154157
ur_context_handle_t Ctxt)
155158
: Function{Func}, FunctionWithOffsetParam{FuncWithOffsetParam},
156159
Name{Name}, Context{Ctxt}, Program{Program}, RefCount{1} {
160+
assert(Program->getDevice());
161+
UR_CHECK_ERROR(urKernelGetGroupInfo(
162+
this, Program->getDevice(),
163+
UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE,
164+
sizeof(ReqdThreadsPerBlock), ReqdThreadsPerBlock, nullptr));
157165
urProgramRetain(Program);
158166
urContextRetain(Context);
159167
}

source/adapters/hip/program.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "program.hpp"
12+
#include "ur_util.hpp"
1213

1314
#ifdef SYCL_ENABLE_KERNEL_FUSION
1415
#ifdef UR_COMGR_VERSION4_INCLUDE
@@ -78,15 +79,6 @@ void getCoMgrBuildLog(const amd_comgr_data_set_t BuildDataSet, char *BuildLog,
7879
} // namespace
7980
#endif
8081

81-
std::pair<std::string, std::string>
82-
splitMetadataName(const std::string &metadataName) {
83-
size_t splitPos = metadataName.rfind('@');
84-
if (splitPos == std::string::npos)
85-
return std::make_pair(metadataName, std::string{});
86-
return std::make_pair(metadataName.substr(0, splitPos),
87-
metadataName.substr(splitPos, metadataName.length()));
88-
}
89-
9082
ur_result_t
9183
ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
9284
size_t Length) {
@@ -107,8 +99,29 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
10799
const char *MetadataValPtrEnd =
108100
MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t);
109101
GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd};
102+
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
103+
// If metadata is reqd_work_group_size, record it for the corresponding
104+
// kernel name.
105+
size_t MDElemsSize = MetadataElement.size - sizeof(std::uint64_t);
106+
107+
// Expect between 1 and 3 32-bit integer values.
108+
UR_ASSERT(MDElemsSize >= sizeof(std::uint32_t) &&
109+
MDElemsSize <= sizeof(std::uint32_t) * 3,
110+
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
111+
112+
// Get pointer to data, skipping 64-bit size at the start of the data.
113+
const char *ValuePtr =
114+
reinterpret_cast<const char *>(MetadataElement.value.pData) +
115+
sizeof(std::uint64_t);
116+
// Read values and pad with 1's for values not present.
117+
std::uint32_t ReqdWorkGroupElements[] = {1, 1, 1};
118+
std::memcpy(ReqdWorkGroupElements, ValuePtr, MDElemsSize);
119+
KernelReqdWorkGroupSizeMD[Prefix] =
120+
std::make_tuple(ReqdWorkGroupElements[0], ReqdWorkGroupElements[1],
121+
ReqdWorkGroupElements[2]);
110122
}
111123
}
124+
112125
return UR_RESULT_SUCCESS;
113126
}
114127

@@ -459,8 +472,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
459472
std::unique_ptr<ur_program_handle_t_> RetProgram{
460473
new ur_program_handle_t_{hContext, hDevice}};
461474

462-
// TODO: Set metadata here and use reqd_work_group_size information.
463-
// See urProgramCreateWithBinary in CUDA adapter.
464475
if (pProperties) {
465476
if (pProperties->count > 0 && pProperties->pMetadatas == nullptr) {
466477
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
@@ -469,8 +480,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
469480
}
470481
Result =
471482
RetProgram->setMetadata(pProperties->pMetadatas, pProperties->count);
483+
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
472484
}
473-
UR_ASSERT(Result == UR_RESULT_SUCCESS, Result);
474485

475486
auto pBinary_string = reinterpret_cast<const char *>(pBinary);
476487
if (size == 0) {

source/adapters/hip/program.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ur_api.h>
1313

1414
#include <atomic>
15+
#include <unordered_map>
1516

1617
#include "context.hpp"
1718

@@ -30,6 +31,8 @@ struct ur_program_handle_t_ {
3031
bool IsRelocatable = false;
3132

3233
std::unordered_map<std::string, std::string> GlobalIDMD;
34+
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
35+
KernelReqdWorkGroupSizeMD;
3336

3437
constexpr static size_t MAX_LOG_SIZE = 8192u;
3538

@@ -38,8 +41,8 @@ struct ur_program_handle_t_ {
3841
ur_program_build_status_t BuildStatus = UR_PROGRAM_BUILD_STATUS_NONE;
3942

4043
ur_program_handle_t_(ur_context_handle_t Ctxt, ur_device_handle_t Device)
41-
: Module{nullptr}, Binary{},
42-
BinarySizeInBytes{0}, RefCount{1}, Context{Ctxt}, Device{Device} {
44+
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
45+
Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{} {
4346
urContextRetain(Context);
4447
urDeviceRetain(Device);
4548
}

source/adapters/native_cpu/program.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "ur_api.h"
1212

1313
#include "common.hpp"
14+
#include "common/ur_util.hpp"
1415
#include "program.hpp"
1516
#include <cstdint>
1617

@@ -27,16 +28,6 @@ urProgramCreateWithIL(ur_context_handle_t hContext, const void *pIL,
2728
DIE_NO_IMPLEMENTATION
2829
}
2930

30-
// TODO: taken from CUDA adapter, move this to a common header?
31-
static std::pair<std::string, std::string>
32-
splitMetadataName(const std::string &metadataName) {
33-
size_t splitPos = metadataName.rfind('@');
34-
if (splitPos == std::string::npos)
35-
return std::make_pair(metadataName, std::string{});
36-
return std::make_pair(metadataName.substr(0, splitPos),
37-
metadataName.substr(splitPos, metadataName.length()));
38-
}
39-
4031
static ur_result_t getReqdWGSize(const ur_program_metadata_t &MetadataElement,
4132
native_cpu::ReqdWGSize_t &res) {
4233
size_t MDElemsSize = MetadataElement.size - sizeof(std::uint64_t);

source/common/ur_util.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,4 +334,13 @@ namespace ur {
334334
}
335335
} // namespace ur
336336

337+
inline std::pair<std::string, std::string>
338+
splitMetadataName(const std::string &metadataName) {
339+
size_t splitPos = metadataName.rfind('@');
340+
if (splitPos == std::string::npos) {
341+
return std::make_pair(metadataName, std::string{});
342+
}
343+
return std::make_pair(metadataName.substr(0, splitPos),
344+
metadataName.substr(splitPos, metadataName.length()));
345+
}
337346
#endif /* UR_UTIL_H */

0 commit comments

Comments
 (0)