Skip to content

Commit 4e535c8

Browse files
authored
[UR][L0] Refactor IL code handling allowing future extension (#18441)
Refactor the code to make it easier to add support for different IL formats (besides SPIR-V) in the future. The only functional change is that SPIR-V binaries with invalid magic number are now rejected by returning UR_RESULT_ERROR_INVALID_BINARY from urProgramCreateWithIL.
1 parent 54438a2 commit 4e535c8

File tree

2 files changed

+87
-17
lines changed

2 files changed

+87
-17
lines changed

unified-runtime/source/adapters/level_zero/program.cpp

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ checkUnresolvedSymbols(ze_module_handle_t ZeModule,
5555
}
5656
} // extern "C"
5757

58+
static ur_program_handle_t_::CodeFormat matchILCodeFormat(const void *Input,
59+
size_t Length) {
60+
const auto MatchMagicNumber = [&](uint32_t Number) {
61+
return Length >= sizeof(Number) &&
62+
std::memcmp(Input, &Number, sizeof(Number)) == 0;
63+
};
64+
65+
// SPIR-V Specification: 3.1 Magic Number
66+
// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Magic
67+
if (MatchMagicNumber(0x07230203)) {
68+
return ur_program_handle_t_::CodeFormat::SPIRV;
69+
}
70+
71+
return ur_program_handle_t_::CodeFormat::Unknown;
72+
}
73+
74+
static bool isCodeFormatIL(ur_program_handle_t_::CodeFormat CodeFormat) {
75+
return CodeFormat == ur_program_handle_t_::CodeFormat::SPIRV;
76+
}
77+
5878
namespace ur::level_zero {
5979

6080
ur_result_t urProgramCreateWithIL(
@@ -70,9 +90,12 @@ ur_result_t urProgramCreateWithIL(
7090
ur_program_handle_t *Program) {
7191
UR_ASSERT(Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
7292
UR_ASSERT(IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
93+
const ur_program_handle_t_::CodeFormat CodeFormat =
94+
matchILCodeFormat(IL, Length);
95+
UR_ASSERT(isCodeFormatIL(CodeFormat), UR_RESULT_ERROR_INVALID_BINARY);
7396
try {
74-
ur_program_handle_t_ *UrProgram =
75-
new ur_program_handle_t_(ur_program_handle_t_::IL, Context, IL, Length);
97+
ur_program_handle_t_ *UrProgram = new ur_program_handle_t_(
98+
ur_program_handle_t_::IL, Context, IL, Length, CodeFormat);
7699
*Program = reinterpret_cast<ur_program_handle_t>(UrProgram);
77100
} catch (const std::bad_alloc &) {
78101
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -195,9 +218,17 @@ ur_result_t urProgramBuildExp(
195218
auto Code = hProgram->getCode(ZeDevice);
196219
UR_ASSERT(Code, UR_RESULT_ERROR_INVALID_PROGRAM);
197220

198-
ZeModuleDesc.format = (State == ur_program_handle_t_::IL)
199-
? ZE_MODULE_FORMAT_IL_SPIRV
200-
: ZE_MODULE_FORMAT_NATIVE;
221+
switch (hProgram->getCodeFormat(ZeDevice)) {
222+
case ur_program_handle_t_::CodeFormat::SPIRV:
223+
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
224+
break;
225+
case ur_program_handle_t_::CodeFormat::Native:
226+
ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE;
227+
break;
228+
default:
229+
ur::unreachable();
230+
return UR_RESULT_ERROR_INVALID_PROGRAM;
231+
}
201232
ZeModuleDesc.inputSize = hProgram->getCodeSize(ZeDevice);
202233
ZeModuleDesc.pInputModule = Code;
203234
ze_context_handle_t ZeContext = hProgram->Context->getZeHandle();
@@ -364,6 +395,8 @@ ur_result_t urProgramLinkExp(
364395
// locks simultaneously with "exclusive" access. However, there is no such
365396
// code like that, so this is also not a danger.
366397
std::vector<std::shared_lock<ur_shared_mutex>> Guards(count);
398+
const ur_program_handle_t_::CodeFormat CommonCodeFormat =
399+
phPrograms[0]->getCodeFormat();
367400
for (uint32_t I = 0; I < count; I++) {
368401
std::shared_lock<ur_shared_mutex> Guard(phPrograms[I]->Mutex);
369402
Guards[I].swap(Guard);
@@ -374,6 +407,13 @@ ur_result_t urProgramLinkExp(
374407
return UR_RESULT_ERROR_INVALID_OPERATION;
375408
}
376409
}
410+
411+
// The L0 API has no way to represent mixed format modules,
412+
// even though it could be possible to implement linking
413+
// of mixed format modules.
414+
if (phPrograms[I]->getCodeFormat() != CommonCodeFormat) {
415+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
416+
}
377417
}
378418

379419
// Previous calls to urProgramCompile did not actually compile the SPIR-V.
@@ -406,7 +446,14 @@ ur_result_t urProgramLinkExp(
406446

407447
ZeStruct<ze_module_desc_t> ZeModuleDesc;
408448
ZeModuleDesc.pNext = &ZeExtModuleDesc;
409-
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
449+
switch (CommonCodeFormat) {
450+
case ur_program_handle_t_::CodeFormat::SPIRV:
451+
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
452+
break;
453+
default:
454+
ur::unreachable();
455+
return UR_RESULT_ERROR_INVALID_PROGRAM;
456+
}
410457

411458
// This works around a bug in the Level Zero driver. When "ZE_DEBUG=-1",
412459
// the driver does validation of the API calls, and it expects
@@ -996,11 +1043,13 @@ ur_result_t urProgramSetSpecializationConstants(
9961043

9971044
ur_program_handle_t_::ur_program_handle_t_(state St,
9981045
ur_context_handle_t Context,
999-
const void *Input, size_t Length)
1046+
const void *Input, size_t Length,
1047+
CodeFormat CodeFormat)
10001048
: Context{Context}, NativeProperties{nullptr}, OwnZeModule{true},
1001-
AssociatedDevices(Context->getDevices()), SpirvCode{new uint8_t[Length]},
1002-
SpirvCodeLength{Length} {
1003-
std::memcpy(SpirvCode.get(), Input, Length);
1049+
AssociatedDevices(Context->getDevices()), ILCode{new uint8_t[Length]},
1050+
ILCodeLength{Length}, ILCodeFormat(CodeFormat) {
1051+
assert(isCodeFormatIL(CodeFormat));
1052+
std::memcpy(ILCode.get(), Input, Length);
10041053
// All devices have the program in IL state.
10051054
for (auto &Device : Context->getDevices()) {
10061055
DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice];

unified-runtime/source/adapters/level_zero/program.hpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ struct ur_program_handle_t_ : ur_object {
4141
Invalid
4242
} state;
4343

44+
enum class CodeFormat : uint8_t {
45+
Native,
46+
SPIRV,
47+
Unknown,
48+
};
49+
4450
// A utility class that converts specialization constants into the form
4551
// required by the Level Zero driver.
4652
class SpecConstantShim {
@@ -68,7 +74,7 @@ struct ur_program_handle_t_ : ur_object {
6874

6975
// Construct a program in IL.
7076
ur_program_handle_t_(state St, ur_context_handle_t Context, const void *Input,
71-
size_t Length);
77+
size_t Length, CodeFormat Format);
7278

7379
// Construct a program in NATIVE for multiple devices.
7480
ur_program_handle_t_(state St, ur_context_handle_t Context,
@@ -113,28 +119,42 @@ struct ur_program_handle_t_ : ur_object {
113119
return DeviceDataMap[ZeDevice].ZeModule;
114120
}
115121

122+
CodeFormat getCodeFormat(ze_device_handle_t ZeDevice = nullptr) const {
123+
if (!ZeDevice)
124+
return ILCodeFormat;
125+
126+
auto It = DeviceDataMap.find(ZeDevice);
127+
if (It == DeviceDataMap.end())
128+
return ILCodeFormat;
129+
130+
if (It->second.State == state::IL)
131+
return ILCodeFormat;
132+
else
133+
return CodeFormat::Native;
134+
}
135+
116136
uint8_t *getCode(ze_device_handle_t ZeDevice = nullptr) {
117137
if (!ZeDevice)
118-
return SpirvCode.get();
138+
return ILCode.get();
119139

120140
if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end())
121141
return nullptr;
122142

123143
if (DeviceDataMap[ZeDevice].State == state::IL)
124-
return SpirvCode.get();
144+
return ILCode.get();
125145
else
126146
return DeviceDataMap[ZeDevice].Binary.first.get();
127147
}
128148

129149
size_t getCodeSize(ze_device_handle_t ZeDevice = nullptr) {
130150
if (ZeDevice == nullptr)
131-
return SpirvCodeLength;
151+
return ILCodeLength;
132152

133153
if (DeviceDataMap.find(ZeDevice) == DeviceDataMap.end())
134154
return 0;
135155

136156
if (DeviceDataMap[ZeDevice].State == state::IL)
137-
return SpirvCodeLength;
157+
return ILCodeLength;
138158
else
139159
return DeviceDataMap[ZeDevice].Binary.second;
140160
}
@@ -233,8 +253,9 @@ struct ur_program_handle_t_ : ur_object {
233253

234254
// In IL and Object states, this contains the SPIR-V representation of the
235255
// module.
236-
std::unique_ptr<uint8_t[]> SpirvCode; // Array containing raw IL code.
237-
size_t SpirvCodeLength = 0; // Size (bytes) of the array.
256+
std::unique_ptr<uint8_t[]> ILCode; // Array containing raw IL code.
257+
size_t ILCodeLength = 0; // Size (bytes) of the array.
258+
CodeFormat ILCodeFormat = CodeFormat::Unknown; // Format of the IL code.
238259

239260
// The Level Zero module handle for interoperability.
240261
// This module handle is either initialized with the handle provided to

0 commit comments

Comments
 (0)