Skip to content

Commit 30cfe3e

Browse files
author
sergei
authored
[SYCL] Avoid copy of plugin information (#5049)
Reuse plugin information so that it's not copied when sycl::detail::plugin is copied
1 parent 9ca7cea commit 30cfe3e

File tree

3 files changed

+23
-17
lines changed

3 files changed

+23
-17
lines changed

sycl/source/detail/pi.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,14 +346,15 @@ int unloadPlugin(void *Library) { return unloadOsLibrary(Library); }
346346
// call is done to get all Interface API mapping. The plugin interface also
347347
// needs to setup infrastructure to route PI_CALLs to the appropriate plugins.
348348
// Currently, we bind to a singe plugin.
349-
bool bindPlugin(void *Library, PiPlugin *PluginInformation) {
349+
bool bindPlugin(void *Library,
350+
const std::shared_ptr<PiPlugin> &PluginInformation) {
350351

351352
decltype(::piPluginInit) *PluginInitializeFunction = (decltype(
352353
&::piPluginInit))(getOsLibraryFuncAddress(Library, "piPluginInit"));
353354
if (PluginInitializeFunction == nullptr)
354355
return false;
355356

356-
int Err = PluginInitializeFunction(PluginInformation);
357+
int Err = PluginInitializeFunction(PluginInformation.get());
357358

358359
// TODO: Compare Supported versions and check for backward compatibility.
359360
// Make sure err is PI_SUCCESS.
@@ -387,11 +388,11 @@ static void initializePlugins(std::vector<plugin> &Plugins) {
387388
std::cerr << "SYCL_PI_TRACE[all]: "
388389
<< "No Plugins Found." << std::endl;
389390

390-
PiPlugin PluginInformation{
391-
_PI_H_VERSION_STRING, _PI_H_VERSION_STRING, nullptr, {}};
392-
PluginInformation.PiFunctionTable = {};
393-
394391
for (unsigned int I = 0; I < PluginNames.size(); I++) {
392+
std::shared_ptr<PiPlugin> PluginInformation = std::make_shared<PiPlugin>(
393+
PiPlugin{_PI_H_VERSION_STRING, _PI_H_VERSION_STRING,
394+
/*Targets=*/nullptr, /*FunctionPointers=*/{}});
395+
395396
void *Library = loadPlugin(PluginNames[I].first);
396397

397398
if (!Library) {
@@ -404,7 +405,7 @@ static void initializePlugins(std::vector<plugin> &Plugins) {
404405
continue;
405406
}
406407

407-
if (!bindPlugin(Library, &PluginInformation)) {
408+
if (!bindPlugin(Library, PluginInformation)) {
408409
if (trace(PI_TRACE_ALL)) {
409410
std::cerr << "SYCL_PI_TRACE[all]: "
410411
<< "Failed to bind PI APIs to the plugin: "

sycl/source/detail/plugin.hpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ auto packCallArguments(ArgsT &&... Args) {
8989
class plugin {
9090
public:
9191
plugin() = delete;
92-
plugin(RT::PiPlugin Plugin, backend UseBackend, void *LibraryHandle)
92+
plugin(const std::shared_ptr<RT::PiPlugin> &Plugin, backend UseBackend,
93+
void *LibraryHandle)
9394
: MPlugin(Plugin), MBackend(UseBackend), MLibraryHandle(LibraryHandle),
9495
TracingMutex(std::make_shared<std::mutex>()),
9596
MPluginMutex(std::make_shared<std::mutex>()) {}
@@ -101,8 +102,11 @@ class plugin {
101102

102103
~plugin() = default;
103104

104-
const RT::PiPlugin &getPiPlugin() const { return MPlugin; }
105-
RT::PiPlugin &getPiPlugin() { return MPlugin; }
105+
const RT::PiPlugin &getPiPlugin() const { return *MPlugin; }
106+
RT::PiPlugin &getPiPlugin() { return *MPlugin; }
107+
const std::shared_ptr<RT::PiPlugin> &getPiPluginPtr() const {
108+
return MPlugin;
109+
}
106110

107111
/// Checks return value from PI calls.
108112
///
@@ -148,29 +152,30 @@ class plugin {
148152
uint64_t CorrelationID = pi::emitFunctionBeginTrace(PIFnName);
149153
auto ArgsData =
150154
packCallArguments<PiApiOffset>(std::forward<ArgsT>(Args)...);
151-
uint64_t CorrelationIDWithArgs = pi::emitFunctionWithArgsBeginTrace(
152-
static_cast<uint32_t>(PiApiOffset), PIFnName, ArgsData.data(), MPlugin);
155+
uint64_t CorrelationIDWithArgs =
156+
pi::emitFunctionWithArgsBeginTrace(static_cast<uint32_t>(PiApiOffset),
157+
PIFnName, ArgsData.data(), *MPlugin);
153158
#endif
154159
RT::PiResult R;
155160
if (pi::trace(pi::TraceLevel::PI_TRACE_CALLS)) {
156161
std::lock_guard<std::mutex> Guard(*TracingMutex);
157162
const char *FnName = PiCallInfo.getFuncName();
158163
std::cout << "---> " << FnName << "(" << std::endl;
159164
RT::printArgs(Args...);
160-
R = PiCallInfo.getFuncPtr(MPlugin)(Args...);
165+
R = PiCallInfo.getFuncPtr(*MPlugin)(Args...);
161166
std::cout << ") ---> ";
162167
RT::printArgs(R);
163168
RT::printOuts(Args...);
164169
std::cout << std::endl;
165170
} else {
166-
R = PiCallInfo.getFuncPtr(MPlugin)(Args...);
171+
R = PiCallInfo.getFuncPtr(*MPlugin)(Args...);
167172
}
168173
#ifdef XPTI_ENABLE_INSTRUMENTATION
169174
// Close the function begin with a call to function end
170175
pi::emitFunctionEndTrace(CorrelationID, PIFnName);
171176
pi::emitFunctionWithArgsEndTrace(CorrelationIDWithArgs,
172177
static_cast<uint32_t>(PiApiOffset),
173-
PIFnName, ArgsData.data(), R, MPlugin);
178+
PIFnName, ArgsData.data(), R, *MPlugin);
174179
#endif
175180
return R;
176181
}
@@ -236,7 +241,7 @@ class plugin {
236241
std::shared_ptr<std::mutex> getPluginMutex() { return MPluginMutex; }
237242

238243
private:
239-
RT::PiPlugin MPlugin;
244+
std::shared_ptr<RT::PiPlugin> MPlugin;
240245
backend MBackend;
241246
void *MLibraryHandle; // the handle returned from dlopen
242247
std::shared_ptr<std::mutex> TracingMutex;

sycl/unittests/helpers/PiMock.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ class PiMock {
123123
// Copy the PiPlugin, thus untying our to-be mock platform from other
124124
// platforms within the context. Reset our platform to use the new plugin.
125125
auto NewPluginPtr = std::make_shared<detail::plugin>(
126-
OriginalPiPlugin.getPiPlugin(), OriginalPiPlugin.getBackend(),
126+
OriginalPiPlugin.getPiPluginPtr(), OriginalPiPlugin.getBackend(),
127127
OriginalPiPlugin.getLibraryHandle());
128128
ImplPtr->setPlugin(NewPluginPtr);
129129
// Extract the new PiPlugin instance by a non-const pointer,

0 commit comments

Comments
 (0)