Skip to content

Commit 5c30815

Browse files
authored
[SYCL] Don't use legacy ANSI-only Windows API for loading plugins (#10943)
Currently to load PI plugins we use legacy ANSI-only versions of Windows API like GetModuleFileNameA, PathRemoveFileSpecA etc. Problem is that if path containing PI plugins has any non-ANSI symbols then PI plugins are not found and not loaded. In this patch get rid of legacy API calls, for example, use GetModuleFileName instead of GetModuleFileNameA. GetModuleFileName is an alias which automatically selects the ANSI or Unicode version of this function. Another difference is that GetModuleFileName and other similar aliases work with wchar_t to be able to handle unicode on Windows (in contrast to legacy GetModuleFileNameA which works with char_t). So, use std::filesystem:path to work with library paths for convenience (instead of storing path in std::string or std::wstring) because it allows to handle paths without caring about format, can be constructed from string/wstring/.. and can be converted to string/wstring ... DPCPP is supported on some linux systems where default compiler is gcc 7.5 which doesn't provide `<filesystem>` support. On Windows, minimal supported version of Visual Studio is 2019 where `<filesystem`> is available (supported since Visual Studio 2017 version 15.7). That's why use filesystem::path only on Windows for now, added TODO to do the same on Linux when matrix support changes.
1 parent 006b882 commit 5c30815

File tree

7 files changed

+144
-77
lines changed

7 files changed

+144
-77
lines changed

sycl/pi_win_proxy_loader/pi_win_proxy_loader.cpp

Lines changed: 38 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
// similar approach.
2424

2525
#include <cassert>
26+
#include <filesystem>
2627

2728
#ifdef _WIN32
2829

@@ -43,22 +44,6 @@
4344

4445
// ------------------------------------
4546

46-
static constexpr const char *DirSep = "\\";
47-
48-
// cribbed from sycl/source/detail/os_util.cpp
49-
std::string getDirName(const char *Path) {
50-
std::string Tmp(Path);
51-
// Remove trailing directory separators
52-
Tmp.erase(Tmp.find_last_not_of("/\\") + 1, std::string::npos);
53-
54-
size_t pos = Tmp.find_last_of("/\\");
55-
if (pos != std::string::npos)
56-
return Tmp.substr(0, pos);
57-
58-
// If no directory separator is present return initial path like dirname does
59-
return Tmp;
60-
}
61-
6247
// cribbed from sycl/source/detail/os_util.cpp
6348
// TODO: Just inline it.
6449
using OSModuleHandle = intptr_t;
@@ -80,20 +65,18 @@ static OSModuleHandle getOSModuleHandle(const void *VirtAddr) {
8065

8166
// cribbed from sycl/source/detail/os_util.cpp
8267
/// Returns an absolute path where the object was found.
83-
std::string getCurrentDSODir() {
84-
char Path[MAX_PATH];
85-
Path[0] = '\0';
86-
Path[sizeof(Path) - 1] = '\0';
68+
std::wstring getCurrentDSODir() {
69+
wchar_t Path[MAX_PATH];
8770
auto Handle = getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODir));
88-
DWORD Ret = GetModuleFileNameA(
89-
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle),
90-
reinterpret_cast<LPSTR>(&Path), sizeof(Path));
71+
DWORD Ret = GetModuleFileName(
72+
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle), Path,
73+
sizeof(Path));
9174
assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?");
92-
assert(Ret > 0 && "GetModuleFileNameA failed");
75+
assert(Ret > 0 && "GetModuleFileName failed");
9376
(void)Ret;
9477

95-
BOOL RetCode = PathRemoveFileSpecA(reinterpret_cast<LPSTR>(&Path));
96-
assert(RetCode && "PathRemoveFileSpecA failed");
78+
BOOL RetCode = PathRemoveFileSpec(Path);
79+
assert(RetCode && "PathRemoveFileSpec failed");
9780
(void)RetCode;
9881

9982
return Path;
@@ -121,7 +104,7 @@ std::string getCurrentDSODir() {
121104

122105
// ------------------------------------
123106

124-
using MapT = std::map<std::string, void *>;
107+
using MapT = std::map<std::filesystem::path, void *>;
125108

126109
MapT &getDllMap() {
127110
static MapT dllMap;
@@ -141,55 +124,60 @@ void preloadLibraries() {
141124
//
142125
UINT SavedMode = SetErrorMode(SEM_FAILCRITICALERRORS);
143126
// Exclude current directory from DLL search path
144-
if (!SetDllDirectoryA("")) {
127+
if (!SetDllDirectory(L"")) {
145128
assert(false && "Failed to update DLL search path");
146129
}
147130

148131
// this path duplicates sycl/detail/pi.cpp:initializePlugins
149-
const std::string LibSYCLDir = getCurrentDSODir() + DirSep;
132+
std::filesystem::path LibSYCLDir(getCurrentDSODir());
150133

151134
MapT &dllMap = getDllMap();
152135

153-
std::string ocl_path = LibSYCLDir + __SYCL_OPENCL_PLUGIN_NAME;
154-
dllMap.emplace(ocl_path, LoadLibraryExA(ocl_path.c_str(), NULL, NULL));
136+
auto ocl_path = LibSYCLDir / __SYCL_OPENCL_PLUGIN_NAME;
137+
dllMap.emplace(ocl_path,
138+
LoadLibraryEx(ocl_path.wstring().c_str(), NULL, NULL));
155139

156-
std::string l0_path = LibSYCLDir + __SYCL_LEVEL_ZERO_PLUGIN_NAME;
157-
dllMap.emplace(l0_path, LoadLibraryExA(l0_path.c_str(), NULL, NULL));
140+
auto l0_path = LibSYCLDir / __SYCL_LEVEL_ZERO_PLUGIN_NAME;
141+
dllMap.emplace(l0_path, LoadLibraryEx(l0_path.wstring().c_str(), NULL, NULL));
158142

159-
std::string cuda_path = LibSYCLDir + __SYCL_CUDA_PLUGIN_NAME;
160-
dllMap.emplace(cuda_path, LoadLibraryExA(cuda_path.c_str(), NULL, NULL));
143+
auto cuda_path = LibSYCLDir / __SYCL_CUDA_PLUGIN_NAME;
144+
dllMap.emplace(cuda_path,
145+
LoadLibraryEx(cuda_path.wstring().c_str(), NULL, NULL));
161146

162-
std::string esimd_path = LibSYCLDir + __SYCL_ESIMD_EMULATOR_PLUGIN_NAME;
163-
dllMap.emplace(esimd_path, LoadLibraryExA(esimd_path.c_str(), NULL, NULL));
147+
auto esimd_path = LibSYCLDir / __SYCL_ESIMD_EMULATOR_PLUGIN_NAME;
148+
dllMap.emplace(esimd_path,
149+
LoadLibraryEx(esimd_path.wstring().c_str(), NULL, NULL));
164150

165-
std::string hip_path = LibSYCLDir + __SYCL_HIP_PLUGIN_NAME;
166-
dllMap.emplace(hip_path, LoadLibraryExA(hip_path.c_str(), NULL, NULL));
151+
auto hip_path = LibSYCLDir / __SYCL_HIP_PLUGIN_NAME;
152+
dllMap.emplace(hip_path,
153+
LoadLibraryEx(hip_path.wstring().c_str(), NULL, NULL));
167154

168-
std::string ur_path = LibSYCLDir + __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME;
169-
dllMap.emplace(ur_path, LoadLibraryExA(ur_path.c_str(), NULL, NULL));
155+
auto ur_path = LibSYCLDir / __SYCL_UNIFIED_RUNTIME_PLUGIN_NAME;
156+
dllMap.emplace(ur_path, LoadLibraryEx(ur_path.wstring().c_str(), NULL, NULL));
170157

171-
std::string nativecpu_path = LibSYCLDir + __SYCL_NATIVE_CPU_PLUGIN_NAME;
158+
auto nativecpu_path = LibSYCLDir / __SYCL_NATIVE_CPU_PLUGIN_NAME;
172159
dllMap.emplace(nativecpu_path,
173-
LoadLibraryExA(nativecpu_path.c_str(), NULL, NULL));
160+
LoadLibraryEx(nativecpu_path.wstring().c_str(), NULL, NULL));
174161

175162
// Restore system error handling.
176163
(void)SetErrorMode(SavedMode);
177-
if (!SetDllDirectoryA(nullptr)) {
164+
if (!SetDllDirectory(nullptr)) {
178165
assert(false && "Failed to restore DLL search path");
179166
}
180167
}
181168

182169
/// windows_pi.cpp:loadOsPluginLibrary() calls this to get the DLL loaded
183170
/// earlier.
184-
__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
171+
__declspec(dllexport) void *getPreloadedPlugin(
172+
const std::filesystem::path &PluginPath) {
185173

186174
MapT &dllMap = getDllMap();
187175

188176
auto match = dllMap.find(PluginPath); // result might be nullptr (not found),
189177
// which is perfectly valid.
190178
if (match == dllMap.end()) {
191179
// unit testing? return nullptr (not found) rather than risk asserting below
192-
if (PluginPath.find("unittests") != std::string::npos)
180+
if (PluginPath.string().find("unittests") != std::string::npos)
193181
return nullptr;
194182

195183
// Otherwise, asking for something we don't know about at all, is an issue.
@@ -200,6 +188,10 @@ __declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
200188
return match->second;
201189
}
202190

191+
__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath) {
192+
return getPreloadedPlugin(std::filesystem::path(PluginPath));
193+
}
194+
203195
BOOL WINAPI DllMain(HINSTANCE hinstDLL, // handle to DLL module
204196
DWORD fdwReason, // reason for calling function
205197
LPVOID lpReserved) // reserved

sycl/pi_win_proxy_loader/pi_win_proxy_loader.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99
#pragma once
1010

1111
#ifdef _WIN32
12+
#include <filesystem>
1213
#include <string>
1314

15+
__declspec(dllexport) void *getPreloadedPlugin(
16+
const std::filesystem::path &PluginPath);
17+
// TODO: Remove this version during ABI breakage window
1418
__declspec(dllexport) void *getPreloadedPlugin(const std::string &PluginPath);
1519
#endif

sycl/source/detail/os_util.cpp

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
#elif defined(__SYCL_RT_OS_WINDOWS)
3131

32+
#include <detail/windows_os_utils.hpp>
33+
3234
#include <Windows.h>
3335
#include <direct.h>
3436
#include <malloc.h>
@@ -139,23 +141,6 @@ std::string OSUtil::getDirName(const char *Path) {
139141
}
140142

141143
#elif defined(__SYCL_RT_OS_WINDOWS)
142-
// TODO: Just inline it.
143-
using OSModuleHandle = intptr_t;
144-
static constexpr OSModuleHandle ExeModuleHandle = -1;
145-
static OSModuleHandle getOSModuleHandle(const void *VirtAddr) {
146-
HMODULE PhModule;
147-
DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
148-
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
149-
auto LpModuleAddr = reinterpret_cast<LPCSTR>(VirtAddr);
150-
if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) {
151-
// Expect the caller to check for zero and take
152-
// necessary action
153-
return 0;
154-
}
155-
if (PhModule == GetModuleHandleA(nullptr))
156-
return ExeModuleHandle;
157-
return reinterpret_cast<OSModuleHandle>(PhModule);
158-
}
159144

160145
/// Returns an absolute path where the object was found.
161146
// pi_win_proxy_loader.dll uses this same logic. If it is changed

sycl/source/detail/pi.cpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <sstream>
3232
#include <stddef.h>
3333
#include <string>
34+
#include <tuple>
3435

3536
#ifdef XPTI_ENABLE_INSTRUMENTATION
3637
// Include the headers necessary for emitting
@@ -435,47 +436,53 @@ std::vector<PluginPtr> &initialize() {
435436
return GlobalHandler::instance().getPlugins();
436437
}
437438

439+
// Implementation of this function is OS specific. Please see windows_pi.cpp and
440+
// posix_pi.cpp.
441+
// TODO: refactor code when support matrix for DPCPP changes and <filesystem> is
442+
// available on all supported systems.
443+
std::vector<std::tuple<std::string, backend, void *>>
444+
loadPlugins(const std::vector<std::pair<std::string, backend>> &&PluginNames);
445+
438446
static void initializePlugins(std::vector<PluginPtr> &Plugins) {
439-
std::vector<std::pair<std::string, backend>> PluginNames = findPlugins();
447+
const std::vector<std::pair<std::string, backend>> PluginNames =
448+
findPlugins();
440449

441450
if (PluginNames.empty() && trace(PI_TRACE_ALL))
442451
std::cerr << "SYCL_PI_TRACE[all]: "
443452
<< "No Plugins Found." << std::endl;
444453

445-
const std::string LibSYCLDir =
446-
sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep;
454+
// Get library handles for the list of plugins.
455+
std::vector<std::tuple<std::string, backend, void *>> LoadedPlugins =
456+
loadPlugins(std::move(PluginNames));
447457

448-
for (unsigned int I = 0; I < PluginNames.size(); I++) {
458+
for (auto [Name, Backend, Library] : LoadedPlugins) {
449459
std::shared_ptr<PiPlugin> PluginInformation = std::make_shared<PiPlugin>(
450460
PiPlugin{_PI_H_VERSION_STRING, _PI_H_VERSION_STRING,
451461
/*Targets=*/nullptr, /*FunctionPointers=*/{}});
452462

453-
void *Library = loadPlugin(LibSYCLDir + PluginNames[I].first);
454-
455463
if (!Library) {
456464
if (trace(PI_TRACE_ALL)) {
457465
std::cerr << "SYCL_PI_TRACE[all]: "
458466
<< "Check if plugin is present. "
459-
<< "Failed to load plugin: " << PluginNames[I].first
460-
<< std::endl;
467+
<< "Failed to load plugin: " << Name << std::endl;
461468
}
462469
continue;
463470
}
464471

465472
if (!bindPlugin(Library, PluginInformation)) {
466473
if (trace(PI_TRACE_ALL)) {
467474
std::cerr << "SYCL_PI_TRACE[all]: "
468-
<< "Failed to bind PI APIs to the plugin: "
469-
<< PluginNames[I].first << std::endl;
475+
<< "Failed to bind PI APIs to the plugin: " << Name
476+
<< std::endl;
470477
}
471478
continue;
472479
}
473-
PluginPtr &NewPlugin = Plugins.emplace_back(std::make_shared<plugin>(
474-
PluginInformation, PluginNames[I].second, Library));
480+
PluginPtr &NewPlugin = Plugins.emplace_back(
481+
std::make_shared<plugin>(PluginInformation, Backend, Library));
475482
if (trace(TraceLevel::PI_TRACE_BASIC))
476483
std::cerr << "SYCL_PI_TRACE[basic]: "
477-
<< "Plugin found and successfully loaded: "
478-
<< PluginNames[I].first << " [ PluginVersion: "
484+
<< "Plugin found and successfully loaded: " << Name
485+
<< " [ PluginVersion: "
479486
<< NewPlugin->getPiPlugin().PluginVersion << " ]" << std::endl;
480487
}
481488

sycl/source/detail/posix_pi.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) {
4848
return dlsym(Library, FunctionName.c_str());
4949
}
5050

51+
// Load plugins corresponding to provided list of plugin names.
52+
std::vector<std::tuple<std::string, backend, void *>>
53+
loadPlugins(const std::vector<std::pair<std::string, backend>> &&PluginNames) {
54+
std::vector<std::tuple<std::string, backend, void *>> LoadedPlugins;
55+
const std::string LibSYCLDir =
56+
sycl::detail::OSUtil::getCurrentDSODir() + sycl::detail::OSUtil::DirSep;
57+
58+
for (auto &PluginName : PluginNames) {
59+
void *Library = loadOsPluginLibrary(LibSYCLDir + PluginName.first);
60+
LoadedPlugins.push_back(std::make_tuple(
61+
std::move(PluginName.first), std::move(PluginName.second), Library));
62+
}
63+
64+
return LoadedPlugins;
65+
}
66+
5167
} // namespace detail::pi
5268
} // namespace _V1
5369
} // namespace sycl
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//==-- windows_os_utils.hpp - Header file with common utils for Windows --==//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===--------------------------------------------------------------------===//
8+
9+
#pragma once
10+
11+
#include <shlwapi.h>
12+
13+
using OSModuleHandle = intptr_t;
14+
constexpr OSModuleHandle ExeModuleHandle = -1;
15+
inline OSModuleHandle getOSModuleHandle(const void *VirtAddr) {
16+
HMODULE PhModule;
17+
DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
18+
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
19+
auto LpModuleAddr = reinterpret_cast<LPCSTR>(VirtAddr);
20+
if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) {
21+
// Expect the caller to check for zero and take
22+
// necessary action
23+
return 0;
24+
}
25+
if (PhModule == GetModuleHandleA(nullptr))
26+
return ExeModuleHandle;
27+
return reinterpret_cast<OSModuleHandle>(PhModule);
28+
}

sycl/source/detail/windows_pi.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include <sycl/backend.hpp>
910
#include <sycl/detail/defines.hpp>
1011

1112
#include <cassert>
1213
#include <string>
1314
#include <windows.h>
1415
#include <winreg.h>
1516

17+
#include "detail/windows_os_utils.hpp"
1618
#include "pi_win_proxy_loader.hpp"
1719

1820
namespace sycl {
@@ -66,6 +68,39 @@ void *getOsLibraryFuncAddress(void *Library, const std::string &FunctionName) {
6668
GetProcAddress((HMODULE)Library, FunctionName.c_str()));
6769
}
6870

71+
static std::filesystem::path getCurrentDSODirPath() {
72+
wchar_t Path[MAX_PATH];
73+
auto Handle =
74+
getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODirPath));
75+
DWORD Ret = GetModuleFileName(
76+
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle), Path,
77+
sizeof(Path));
78+
assert(Ret < sizeof(Path) && "Path is longer than PATH_MAX?");
79+
assert(Ret > 0 && "GetModuleFileName failed");
80+
(void)Ret;
81+
82+
BOOL RetCode = PathRemoveFileSpec(Path);
83+
assert(RetCode && "PathRemoveFileSpec failed");
84+
(void)RetCode;
85+
86+
return std::filesystem::path(Path);
87+
}
88+
89+
// Load plugins corresponding to provided list of plugin names.
90+
std::vector<std::tuple<std::string, backend, void *>>
91+
loadPlugins(const std::vector<std::pair<std::string, backend>> &&PluginNames) {
92+
std::vector<std::tuple<std::string, backend, void *>> LoadedPlugins;
93+
const std::filesystem::path LibSYCLDir = getCurrentDSODirPath();
94+
95+
for (auto &PluginName : PluginNames) {
96+
void *Library = getPreloadedPlugin(LibSYCLDir / PluginName.first);
97+
LoadedPlugins.push_back(std::make_tuple(
98+
std::move(PluginName.first), std::move(PluginName.second), Library));
99+
}
100+
101+
return LoadedPlugins;
102+
}
103+
69104
} // namespace pi
70105
} // namespace detail
71106
} // namespace _V1

0 commit comments

Comments
 (0)