Skip to content

Commit be4e641

Browse files
authored
SYCL] Support specialization constants in the L0 plugin (#1982)
Author: Jaime Arteaga <[email protected]> Signed-off-by: Artur Gainullin <[email protected]>
1 parent a51c333 commit be4e641

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

sycl/plugins/level_zero/pi_level0.cpp

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <string>
2121
#include <thread>
2222
#include <utility>
23+
#include <vector>
2324

2425
#include <level_zero/zet_api.h>
2526

@@ -1824,9 +1825,29 @@ pi_result piProgramBuild(pi_program Program, pi_uint32 NumDevices,
18241825
// Check that the program wasn't already built.
18251826
assert(!Program->ZeModule);
18261827

1828+
// Translate collected specialization constants.
1829+
ze_module_constants_t ZeSpecConstants = {};
1830+
std::vector<uint32_t> ZeSpecContantsIds(Program->ZeSpecConstants.size());
1831+
std::vector<uint64_t> ZeSpecContantsValues(Program->ZeSpecConstants.size());
1832+
{
1833+
std::lock_guard<std::mutex> ZeSpecConstantsMutexGuard(
1834+
Program->ZeSpecConstantsMutex);
1835+
ZeSpecConstants.numConstants = Program->ZeSpecConstants.size();
1836+
auto ZeSpecContantsIdsIt = ZeSpecContantsIds.begin();
1837+
auto ZeSpecContantsValuesIt = ZeSpecContantsValues.begin();
1838+
for (auto &SpecConstant : Program->ZeSpecConstants) {
1839+
*ZeSpecContantsIdsIt = SpecConstant.first;
1840+
ZeSpecContantsIdsIt++;
1841+
*ZeSpecContantsValuesIt = SpecConstant.second;
1842+
ZeSpecContantsValuesIt++;
1843+
}
1844+
ZeSpecConstants.pConstantIds = ZeSpecContantsIds.data();
1845+
ZeSpecConstants.pConstantValues = ZeSpecContantsValues.data();
1846+
}
1847+
1848+
// Complete the module's descriptor
18271849
Program->ZeModuleDesc.pBuildFlags = Options;
1828-
// TODO: set specialization constants here.
1829-
Program->ZeModuleDesc.pConstants = nullptr;
1850+
Program->ZeModuleDesc.pConstants = &ZeSpecConstants;
18301851

18311852
ze_device_handle_t ZeDevice = Program->Context->Device->ZeDevice;
18321853
ZE_CALL(zeModuleCreate(ZeDevice, &Program->ZeModuleDesc, &Program->ZeModule,
@@ -3713,12 +3734,20 @@ pi_result piKernelSetExecInfo(pi_kernel Kernel, pi_kernel_exec_info ParamName,
37133734
}
37143735

37153736
pi_result piextProgramSetSpecializationConstant(pi_program Prog,
3716-
pi_uint32 SpecID,
3717-
size_t SpecSize,
3737+
pi_uint32 SpecID, size_t,
37183738
const void *SpecValue) {
3719-
// TODO: implement
3720-
die("piextProgramSetSpecializationConstant: not implemented");
3721-
return {};
3739+
// Level Zero sets spec constants when creating modules,
3740+
// so save them for when program is built.
3741+
std::lock_guard<std::mutex> ZeSpecConstantsMutexGuard(
3742+
Prog->ZeSpecConstantsMutex);
3743+
3744+
// Pass SpecValue pointer. Spec constant value is retrieved
3745+
// by Level-Zero when creating the modul
3746+
//
3747+
// NOTE: SpecSize is unused in L0, the size is known from SPIR-V by SpecID.
3748+
Prog->ZeSpecConstants[SpecID] = reinterpret_cast<uint64_t>(SpecValue);
3749+
3750+
return PI_SUCCESS;
37223751
}
37233752

37243753
pi_result piPluginInit(pi_plugin *PluginInit) {

sycl/plugins/level_zero/pi_level0.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,9 @@ struct _pi_program : _pi_object {
316316

317317
// L0 module handle.
318318
ze_module_handle_t ZeModule;
319+
// L0 module specialization constants
320+
std::mutex ZeSpecConstantsMutex;
321+
std::unordered_map<uint32_t, uint64_t> ZeSpecConstants;
319322

320323
// L0 build log.
321324
ze_module_build_log_handle_t ZeBuildLog;

0 commit comments

Comments
 (0)