Skip to content

Commit fa380f3

Browse files
committed
[SYCL] Addition of suggested changes from PR#808 and PR#843.
Signed-off-by: Garima Gupta <[email protected]>
1 parent f60cb6c commit fa380f3

File tree

7 files changed

+47
-54
lines changed

7 files changed

+47
-54
lines changed

sycl/include/CL/sycl/detail/pi.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -917,13 +917,18 @@ pi_result piEnqueueMemUnmap(
917917
#define STRING_HELPER(a) #a
918918
#define STRINGIZE(a,b) STRING_HELPER(a.b)
919919

920-
struct pi_plugin {
920+
struct _pi_plugin{
921921
// PI version supported by host passed to the plugin. The Plugin
922922
// checks and writes the appropriate Function Pointers in
923-
// PIFunctionTable.
923+
// PiFunctionTable.
924+
// TODO: Work on version fields and their handshaking mechanism.
925+
// Some choices are:
926+
// - Use of integers to keep major and minor version.
927+
// - Keeping char* Versions.
924928
const char PiVersion[4] = STRINGIZE(_PI_H_VERSION_MAJOR,_PI_H_VERSION_MINOR);
929+
// Plugin edits this.
925930
char PluginVersion[4] =
926-
STRINGIZE(_PI_H_VERSION_MAJOR,_PI_H_VERSION_MINOR); // Plugin edits this.
931+
STRINGIZE(_PI_H_VERSION_MAJOR,_PI_H_VERSION_MINOR);
927932
char *Targets;
928933
struct FunctionPointers {
929934
#define _PI_API(api) decltype(::api) *api;

sycl/include/CL/sycl/detail/pi.hpp

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,9 @@ void printArgs(Arg0 arg0, Args... args) {
119119

120120
// Utility function to check return from pi calls.
121121
// Throws if pi_result is not a PI_SUCCESS.
122-
// TODO: Absorb this utility in Trace Class
123-
template <typename Exception> inline void piCheckThrow(PiResult pi_result) {
124-
CHECK_OCL_CODE_THROW(pi_result, Exception);
125-
}
126-
127-
// Utility function to check if return from pi call is
128-
// PI_SUCCESS. If is it not, throw a cl::sycl::runtime_error.
129-
// TODO: Absorb this utility in Trace Class
122+
template <typename Exception = cl::sycl::runtime_error>
130123
inline void piCheckResult(PiResult pi_result) {
131-
piCheckThrow<cl::sycl::runtime_error>(pi_result);
124+
CHECK_OCL_CODE_THROW(pi_result, Exception);
132125
}
133126

134127
#define PI_TRACE_ENABLED (std::getenv("SYCL_PI_TRACE") != nullptr)
@@ -141,30 +134,32 @@ template <typename FnType, size_t FnOffset> class Trace {
141134
public:
142135
Trace();
143136
template <typename... Args> PiResult operator()(Args... args) {
144-
if (PI_TRACE_ENABLED) {
137+
bool enableTrace = PI_TRACE_ENABLED;
138+
if (enableTrace) {
145139
std::cout << "---> " << m_FnName << "(";
146140
printArgs(args...);
147141
}
148142

149143
PiResult r = m_FnPtr(args...);
150144

151-
if (PI_TRACE_ENABLED) {
145+
if (enableTrace) {
152146
std::cout << ") ---> ";
153147
std::cout << (print(r), "") << std::endl;
154148
}
155149
return r;
156150
}
157151
};
158152

159-
template <typename FnType, size_t FnOffset>
153+
template <typename FnType, size_t FnOffset,
154+
typename Exception = cl::sycl::runtime_error>
160155
class TraceCheck : private Trace<FnType, FnOffset> {
161156
public:
162157
TraceCheck() : Trace<FnType, FnOffset>(){};
163158

164-
template <typename Exception = cl::sycl::runtime_error, typename... Args>
159+
template <typename... Args>
165160
void operator()(Args... args) {
166161
PiResult Err = (Trace<FnType, FnOffset>::operator()(args...));
167-
piCheckThrow<Exception>(Err);
162+
piCheckResult<Exception>(Err);
168163
}
169164
};
170165

@@ -174,7 +169,7 @@ class TraceCheck : private Trace<FnType, FnOffset> {
174169
#define _PI_API(api) \
175170
template <> \
176171
Trace<decltype(&::api), \
177-
(offsetof(_pi_plugin::FunctionPointers, api))>::Trace();
172+
(offsetof(pi_plugin::FunctionPointers, api))>::Trace();
178173

179174
#include <CL/sycl/detail/pi.def>
180175

@@ -187,23 +182,25 @@ namespace RT = cl::sycl::detail::pi;
187182
// Use this macro to call the API, trace the call, check the return and throw a
188183
// runtime_error exception.
189184
// Usage: PI_CALL(pi)(Args);
190-
// Note: To change the exception type, use:
191-
// PI_CALL(pi).template operator()<compile_program_error>(__VA_ARGS__)
192-
// Or
193-
// auto Err = PI_CALL_NOCHECK(pi)(args);
194-
// RT::piCheckThrow<Exception>(Err);
195185
#define PI_CALL(pi) \
196186
RT::TraceCheck<decltype(&::pi), \
197-
(offsetof(_pi_plugin::FunctionPointers, pi))>()
187+
(offsetof(pi_plugin::FunctionPointers, pi))>()
198188

199189
// Use this macro to call the API, trace the call and return the result.
200-
// To check the result use piCheckResult or piCheckThrow.
190+
// To check the result use piCheckResult.
201191
// Usage:
202192
// PiResult Err = PI_CALL_NOCHECK(pi)(args);
203193
// RT::piCheckResult(Err); <- Checks Result and throws a runtime error
204194
// exception.
205195
#define PI_CALL_NOCHECK(pi) \
206-
RT::Trace<decltype(&::pi), (offsetof(_pi_plugin::FunctionPointers, pi))>()
196+
RT::Trace<decltype(&::pi), (offsetof(pi_plugin::FunctionPointers, pi))>()
197+
198+
// Use this macro to call the API, trace the call, check the return and throw a
199+
// Exception as given in the MACRO.
200+
// Usage: PI_CALL_THROW(pi, compile_program_error)(args);
201+
#define PI_CALL_THROW(pi, Exception) \
202+
RT::TraceCheck<decltype(&::pi), \
203+
(offsetof(pi_plugin::FunctionPointers, pi)), Exception>()
207204

208205
// Want all the needed casts be explicit, do not define conversion
209206
// operators.

sycl/include/CL/sycl/detail/program_impl.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,10 @@ class program_impl {
8989
NonInterOpToLink |= !Prg->IsLinkable;
9090
Programs.push_back(Prg->Program);
9191
}
92-
RT::PiResult Err = PI_SUCCESS;
93-
Err = PI_CALL_NOCHECK(piProgramLink)(
92+
PI_CALL_THROW(piProgramLink, compile_program_error)(
9493
detail::getSyclObjImpl(Context)->getHandleRef(), Devices.size(),
9594
Devices.data(), LinkOptions.c_str(), Programs.size(), Programs.data(),
9695
nullptr, nullptr, &Program);
97-
RT::piCheckThrow<compile_program_error>(Err);
9896
}
9997
}
10098

@@ -247,11 +245,10 @@ class program_impl {
247245
check_device_feature_support<
248246
info::device::is_linker_available>(Devices);
249247
vector_class<RT::PiDevice> Devices(get_pi_devices());
250-
RT::PiResult Err = PI_CALL_RESULT(piProgramLink)(
248+
PI_CALL_THROW(piProgramLink, compile_program_error)(
251249
detail::getSyclObjImpl(Context)->getHandleRef(), Devices.size(),
252250
Devices.data(), LinkOptions.c_str(), 1, &Program, nullptr, nullptr,
253251
&Program);
254-
RT::piCheckThrow<compile_program_error>(Err);
255252
this->LinkOptions = LinkOptions;
256253
BuildOptions = LinkOptions;
257254
}

sycl/include/CL/sycl/detail/sycl_mem_obj_t.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,13 @@ template <typename AllocatorT> class SYCLMemObjT : public SYCLMemObjI {
8686

8787
RT::PiMem Mem = pi::cast<RT::PiMem>(MInteropMemObject);
8888
RT::PiContext Context = nullptr;
89-
RT::piCheckResult(PI_CALL_NOCHECK(piMemGetInfo)(
90-
Mem, CL_MEM_CONTEXT, sizeof(Context), &Context, nullptr));
89+
PI_CALL(piMemGetInfo)(
90+
Mem, CL_MEM_CONTEXT, sizeof(Context), &Context, nullptr);
9191

9292
if (MInteropContext->getHandleRef() != Context)
9393
throw cl::sycl::invalid_parameter_error(
9494
"Input context must be the same as the context of cl_mem");
95-
RT::piCheckResult(PI_CALL_NOCHECK(piMemRetain)(Mem));
95+
PI_CALL(piMemRetain)(Mem);
9696
}
9797

9898
SYCLMemObjT(cl_mem MemObject, const context &SyclContext,

sycl/plugins/opencl/pi_opencl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,14 +460,14 @@ pi_result OCL(piEnqueueMemBufferMap)(
460460
}
461461

462462
pi_result piPluginInit(pi_plugin *PluginInit) {
463-
strcpy(PluginInit->PluginVersion, SupportedVersion);
464463
int CompareVersions = strcmp(PluginInit->PiVersion, SupportedVersion);
465464
if (CompareVersions < 0) {
466465
// PI interface supports lower version of PI.
467-
assert(false && "incompatible versions.!!\n");
466+
// TODO: Take appropriate actions.
468467
return PI_INVALID_OPERATION;
469468
} else {
470469
// PI interface supports higher version or the same version.
470+
strncpy(PluginInit->PluginVersion, SupportedVersion, 4);
471471

472472
#define _PI_CL(pi_api, ocl_api) \
473473
(PluginInit->PiFunctionTable).pi_api = (decltype(&::pi_api))(&ocl_api);

sycl/source/detail/pi.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,12 @@ bool bindPlugin(void *Library) {
8080
decltype(::piPluginInit) *PluginInitializeFunction = (decltype(
8181
&::piPluginInit))(getOsLibraryFuncAddress(Library, "piPluginInit"));
8282
int err = PluginInitializeFunction(&PluginInformation);
83-
int CompareVersions =
84-
strcmp(PluginInformation.PiVersion, PluginInformation.PluginVersion);
85-
86-
// CompareVersions >= 0, Plugin Interface supports same/higher PI version as
87-
// the Plugin.
88-
// TODO: When Plugin supports lower version of PI, check for backward
89-
// compatibility.
90-
assert((CompareVersions >= 0) && "Plugin Interface supports lower PI version "
91-
"than Plugin. Update library.");
92-
// Reaching here means CompareVersions>=0, make sure err is PI_SUCCESS.
83+
84+
// TODO: Compare Supported versions and check for backward compatibility.
85+
// Make sure err is PI_SUCCESS.
9386
assert((err == PI_SUCCESS) && "Unexpected error when binding to Plugin.");
87+
88+
// TODO: Return a more meaningful value/enum.
9489
return true;
9590
}
9691

@@ -150,7 +145,7 @@ void assertion(bool Condition, const char *Message) {
150145
#define _PI_API(api) \
151146
template <> \
152147
Trace<decltype(&::api), \
153-
(offsetof(_pi_plugin::FunctionPointers, api))>::Trace() { \
148+
(offsetof(pi_plugin::FunctionPointers, api))>::Trace() { \
154149
initialize(); \
155150
m_FnPtr = (RT::PluginInformation.PiFunctionTable.api); \
156151
m_FnName = #api; \

sycl/unittests/pi/PlatformTest.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ class PlatformTest : public ::testing::Test {
3232
// Initialize the logged number of platforms before the following assertion.
3333
RecordProperty(platform_count_key, platform_count);
3434

35-
ASSERT_EQ(RT::PluginInformation.PiFunctionTable.piPlatformsGet(
36-
0, 0, &platform_count),
35+
ASSERT_EQ(PI_CALL_NOCHECK(piPlatformsGet)(0, 0, &platform_count),
3736
PI_SUCCESS);
3837

3938
// Overwrite previous log value with queried number of platforms.
@@ -50,8 +49,8 @@ class PlatformTest : public ::testing::Test {
5049

5150
_platforms.resize(platform_count, nullptr);
5251

53-
ASSERT_EQ(RT::PluginInformation.PiFunctionTable.piPlatformsGet(
54-
_platforms.size(), _platforms.data(), nullptr),
52+
ASSERT_EQ(PI_CALL_NOCHECK(piPlatformsGet)(_platforms.size(),
53+
_platforms.data(), nullptr),
5554
PI_SUCCESS);
5655
}
5756
};
@@ -64,13 +63,13 @@ TEST_F(PlatformTest, piPlatformsGet) {
6463
TEST_F(PlatformTest, piPlatformGetInfo) {
6564
auto get_info_test = [](pi_platform platform, _pi_platform_info info) {
6665
size_t reported_string_length = 0;
67-
EXPECT_EQ(RT::PluginInformation.PiFunctionTable.piPlatformGetInfo(
68-
platform, info, 0u, nullptr, &reported_string_length),
66+
EXPECT_EQ(PI_CALL_NOCHECK(piPlatformGetInfo)(platform, info, 0u, nullptr,
67+
&reported_string_length),
6968
PI_SUCCESS);
7069

7170
// Create a larger result string to catch overwrites.
7271
std::vector<char> param_value(reported_string_length * 2u, '\0');
73-
EXPECT_EQ(RT::PluginInformation.PiFunctionTable.piPlatformGetInfo(
72+
EXPECT_EQ(PI_CALL_NOCHECK(piPlatformGetInfo)(
7473
platform, info, param_value.size(), param_value.data(), 0u),
7574
PI_SUCCESS)
7675
<< "piPlatformGetInfo for " << RT::platformInfoToString(info)

0 commit comments

Comments
 (0)