@@ -55,6 +55,26 @@ checkUnresolvedSymbols(ze_module_handle_t ZeModule,
55
55
}
56
56
} // extern "C"
57
57
58
+ static ur_program_handle_t_::CodeFormat matchILCodeFormat (const void *Input,
59
+ size_t Length) {
60
+ const auto MatchMagicNumber = [&](uint32_t Number) {
61
+ return Length >= sizeof (Number) &&
62
+ std::memcmp (Input, &Number, sizeof (Number)) == 0 ;
63
+ };
64
+
65
+ // SPIR-V Specification: 3.1 Magic Number
66
+ // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#Magic
67
+ if (MatchMagicNumber (0x07230203 )) {
68
+ return ur_program_handle_t_::CodeFormat::SPIRV;
69
+ }
70
+
71
+ return ur_program_handle_t_::CodeFormat::Unknown;
72
+ }
73
+
74
+ static bool isCodeFormatIL (ur_program_handle_t_::CodeFormat CodeFormat) {
75
+ return CodeFormat == ur_program_handle_t_::CodeFormat::SPIRV;
76
+ }
77
+
58
78
namespace ur ::level_zero {
59
79
60
80
ur_result_t urProgramCreateWithIL (
@@ -70,9 +90,12 @@ ur_result_t urProgramCreateWithIL(
70
90
ur_program_handle_t *Program) {
71
91
UR_ASSERT (Context, UR_RESULT_ERROR_INVALID_NULL_HANDLE);
72
92
UR_ASSERT (IL && Program, UR_RESULT_ERROR_INVALID_NULL_POINTER);
93
+ const ur_program_handle_t_::CodeFormat CodeFormat =
94
+ matchILCodeFormat (IL, Length);
95
+ UR_ASSERT (isCodeFormatIL (CodeFormat), UR_RESULT_ERROR_INVALID_BINARY);
73
96
try {
74
- ur_program_handle_t_ *UrProgram =
75
- new ur_program_handle_t_ (ur_program_handle_t_ ::IL, Context, IL, Length);
97
+ ur_program_handle_t_ *UrProgram = new ur_program_handle_t_ (
98
+ ur_program_handle_t_::IL, Context, IL, Length, CodeFormat );
76
99
*Program = reinterpret_cast <ur_program_handle_t >(UrProgram);
77
100
} catch (const std::bad_alloc &) {
78
101
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
@@ -195,9 +218,17 @@ ur_result_t urProgramBuildExp(
195
218
auto Code = hProgram->getCode (ZeDevice);
196
219
UR_ASSERT (Code, UR_RESULT_ERROR_INVALID_PROGRAM);
197
220
198
- ZeModuleDesc.format = (State == ur_program_handle_t_::IL)
199
- ? ZE_MODULE_FORMAT_IL_SPIRV
200
- : ZE_MODULE_FORMAT_NATIVE;
221
+ switch (hProgram->getCodeFormat (ZeDevice)) {
222
+ case ur_program_handle_t_::CodeFormat::SPIRV:
223
+ ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
224
+ break ;
225
+ case ur_program_handle_t_::CodeFormat::Native:
226
+ ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE;
227
+ break ;
228
+ default :
229
+ ur::unreachable ();
230
+ return UR_RESULT_ERROR_INVALID_PROGRAM;
231
+ }
201
232
ZeModuleDesc.inputSize = hProgram->getCodeSize (ZeDevice);
202
233
ZeModuleDesc.pInputModule = Code;
203
234
ze_context_handle_t ZeContext = hProgram->Context ->getZeHandle ();
@@ -364,6 +395,8 @@ ur_result_t urProgramLinkExp(
364
395
// locks simultaneously with "exclusive" access. However, there is no such
365
396
// code like that, so this is also not a danger.
366
397
std::vector<std::shared_lock<ur_shared_mutex>> Guards (count);
398
+ const ur_program_handle_t_::CodeFormat CommonCodeFormat =
399
+ phPrograms[0 ]->getCodeFormat ();
367
400
for (uint32_t I = 0 ; I < count; I++) {
368
401
std::shared_lock<ur_shared_mutex> Guard (phPrograms[I]->Mutex );
369
402
Guards[I].swap (Guard);
@@ -374,6 +407,13 @@ ur_result_t urProgramLinkExp(
374
407
return UR_RESULT_ERROR_INVALID_OPERATION;
375
408
}
376
409
}
410
+
411
+ // The L0 API has no way to represent mixed format modules,
412
+ // even though it could be possible to implement linking
413
+ // of mixed format modules.
414
+ if (phPrograms[I]->getCodeFormat () != CommonCodeFormat) {
415
+ return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
416
+ }
377
417
}
378
418
379
419
// Previous calls to urProgramCompile did not actually compile the SPIR-V.
@@ -406,7 +446,14 @@ ur_result_t urProgramLinkExp(
406
446
407
447
ZeStruct<ze_module_desc_t > ZeModuleDesc;
408
448
ZeModuleDesc.pNext = &ZeExtModuleDesc;
409
- ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
449
+ switch (CommonCodeFormat) {
450
+ case ur_program_handle_t_::CodeFormat::SPIRV:
451
+ ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
452
+ break ;
453
+ default :
454
+ ur::unreachable ();
455
+ return UR_RESULT_ERROR_INVALID_PROGRAM;
456
+ }
410
457
411
458
// This works around a bug in the Level Zero driver. When "ZE_DEBUG=-1",
412
459
// the driver does validation of the API calls, and it expects
@@ -996,11 +1043,13 @@ ur_result_t urProgramSetSpecializationConstants(
996
1043
997
1044
ur_program_handle_t_::ur_program_handle_t_ (state St,
998
1045
ur_context_handle_t Context,
999
- const void *Input, size_t Length)
1046
+ const void *Input, size_t Length,
1047
+ CodeFormat CodeFormat)
1000
1048
: Context{Context}, NativeProperties{nullptr }, OwnZeModule{true },
1001
- AssociatedDevices (Context->getDevices ()), SpirvCode{new uint8_t [Length]},
1002
- SpirvCodeLength{Length} {
1003
- std::memcpy (SpirvCode.get (), Input, Length);
1049
+ AssociatedDevices (Context->getDevices ()), ILCode{new uint8_t [Length]},
1050
+ ILCodeLength{Length}, ILCodeFormat(CodeFormat) {
1051
+ assert (isCodeFormatIL (CodeFormat));
1052
+ std::memcpy (ILCode.get (), Input, Length);
1004
1053
// All devices have the program in IL state.
1005
1054
for (auto &Device : Context->getDevices ()) {
1006
1055
DeviceData &PerDevData = DeviceDataMap[Device->ZeDevice ];
0 commit comments