|
20 | 20 | #include <string>
|
21 | 21 | #include <thread>
|
22 | 22 | #include <utility>
|
| 23 | +#include <vector> |
23 | 24 |
|
24 | 25 | #include <level_zero/zet_api.h>
|
25 | 26 |
|
@@ -1824,9 +1825,29 @@ pi_result piProgramBuild(pi_program Program, pi_uint32 NumDevices,
|
1824 | 1825 | // Check that the program wasn't already built.
|
1825 | 1826 | assert(!Program->ZeModule);
|
1826 | 1827 |
|
| 1828 | + // Translate collected specialization constants. |
| 1829 | + ze_module_constants_t ZeSpecConstants = {}; |
| 1830 | + std::vector<uint32_t> ZeSpecContantsIds(Program->ZeSpecConstants.size()); |
| 1831 | + std::vector<uint64_t> ZeSpecContantsValues(Program->ZeSpecConstants.size()); |
| 1832 | + { |
| 1833 | + std::lock_guard<std::mutex> ZeSpecConstantsMutexGuard( |
| 1834 | + Program->ZeSpecConstantsMutex); |
| 1835 | + ZeSpecConstants.numConstants = Program->ZeSpecConstants.size(); |
| 1836 | + auto ZeSpecContantsIdsIt = ZeSpecContantsIds.begin(); |
| 1837 | + auto ZeSpecContantsValuesIt = ZeSpecContantsValues.begin(); |
| 1838 | + for (auto &SpecConstant : Program->ZeSpecConstants) { |
| 1839 | + *ZeSpecContantsIdsIt = SpecConstant.first; |
| 1840 | + ZeSpecContantsIdsIt++; |
| 1841 | + *ZeSpecContantsValuesIt = SpecConstant.second; |
| 1842 | + ZeSpecContantsValuesIt++; |
| 1843 | + } |
| 1844 | + ZeSpecConstants.pConstantIds = ZeSpecContantsIds.data(); |
| 1845 | + ZeSpecConstants.pConstantValues = ZeSpecContantsValues.data(); |
| 1846 | + } |
| 1847 | + |
| 1848 | + // Complete the module's descriptor |
1827 | 1849 | Program->ZeModuleDesc.pBuildFlags = Options;
|
1828 |
| - // TODO: set specialization constants here. |
1829 |
| - Program->ZeModuleDesc.pConstants = nullptr; |
| 1850 | + Program->ZeModuleDesc.pConstants = &ZeSpecConstants; |
1830 | 1851 |
|
1831 | 1852 | ze_device_handle_t ZeDevice = Program->Context->Device->ZeDevice;
|
1832 | 1853 | ZE_CALL(zeModuleCreate(ZeDevice, &Program->ZeModuleDesc, &Program->ZeModule,
|
@@ -3713,12 +3734,20 @@ pi_result piKernelSetExecInfo(pi_kernel Kernel, pi_kernel_exec_info ParamName,
|
3713 | 3734 | }
|
3714 | 3735 |
|
3715 | 3736 | pi_result piextProgramSetSpecializationConstant(pi_program Prog,
|
3716 |
| - pi_uint32 SpecID, |
3717 |
| - size_t SpecSize, |
| 3737 | + pi_uint32 SpecID, size_t, |
3718 | 3738 | const void *SpecValue) {
|
3719 |
| - // TODO: implement |
3720 |
| - die("piextProgramSetSpecializationConstant: not implemented"); |
3721 |
| - return {}; |
| 3739 | + // Level Zero sets spec constants when creating modules, |
| 3740 | + // so save them for when program is built. |
| 3741 | + std::lock_guard<std::mutex> ZeSpecConstantsMutexGuard( |
| 3742 | + Prog->ZeSpecConstantsMutex); |
| 3743 | + |
| 3744 | + // Pass SpecValue pointer. Spec constant value is retrieved |
| 3745 | + // by Level-Zero when creating the modul |
| 3746 | + // |
| 3747 | + // NOTE: SpecSize is unused in L0, the size is known from SPIR-V by SpecID. |
| 3748 | + Prog->ZeSpecConstants[SpecID] = reinterpret_cast<uint64_t>(SpecValue); |
| 3749 | + |
| 3750 | + return PI_SUCCESS; |
3722 | 3751 | }
|
3723 | 3752 |
|
3724 | 3753 | pi_result piPluginInit(pi_plugin *PluginInit) {
|
|
0 commit comments