Skip to content

Commit 5ee2b0f

Browse files
committed
Fix test failures with AoT compilation
Programs created via piProgramCreateWithBinary() are always built later via piProgramBuild(). Change piProgramBuild(), so this no longer raises an error. Also fix some other behavior of programs created with piProgramCreateWithBinary(), following the model of the OpenCL spec.
1 parent 7a57bf8 commit 5ee2b0f

File tree

2 files changed

+141
-90
lines changed

2 files changed

+141
-90
lines changed

sycl/plugins/level_zero/pi_level_zero.cpp

Lines changed: 127 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,9 @@ extern "C" {
427427
// Forward declarations
428428
decltype(piEventCreate) piEventCreate;
429429

430+
static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
431+
const pi_device *DeviceList,
432+
const char *Options);
430433
static pi_result copyModule(ze_device_handle_t ZeDevice,
431434
ze_module_handle_t SrcMod,
432435
ze_module_handle_t *DestMod);
@@ -1680,15 +1683,19 @@ pi_result piextMemCreateWithNativeHandle(pi_native_handle NativeHandle,
16801683
pi_result piProgramCreate(pi_context Context, const void *ILBytes,
16811684
size_t Length, pi_program *Program) {
16821685

1683-
assert(Context);
1684-
assert(Program);
1686+
if (!Context)
1687+
return PI_INVALID_CONTEXT;
1688+
if (!ILBytes || !Length)
1689+
return PI_INVALID_VALUE;
1690+
if (!Program)
1691+
return PI_INVALID_VALUE;
16851692

16861693
// NOTE: the Level Zero module creation is also building the program, so we
16871694
// are deferring it until the program is ready to be built in piProgramBuild
16881695
// and piProgramCompile. Also it is only then we know the build options.
16891696

16901697
try {
1691-
*Program = new _pi_program(Context, ILBytes, Length);
1698+
*Program = new _pi_program(Context, ILBytes, Length, _pi_program::IL);
16921699
} catch (const std::bad_alloc &) {
16931700
return PI_OUT_OF_HOST_MEMORY;
16941701
} catch (...) {
@@ -1702,46 +1709,44 @@ pi_result piProgramCreateWithBinary(pi_context Context, pi_uint32 NumDevices,
17021709
const size_t *Lengths,
17031710
const unsigned char **Binaries,
17041711
pi_int32 *BinaryStatus,
1705-
pi_program *RetProgram) {
1712+
pi_program *Program) {
17061713

1707-
// This must be for the single device in this context.
1714+
if (!Context)
1715+
return PI_INVALID_CONTEXT;
1716+
if (!DeviceList || !NumDevices)
1717+
return PI_INVALID_VALUE;
1718+
if (!Binaries || !Lengths || !Binaries[0] || !Lengths[0])
1719+
return PI_INVALID_VALUE;
1720+
if (!Program)
1721+
return PI_INVALID_VALUE;
1722+
1723+
// For now we support only one device.
17081724
assert(NumDevices == 1);
1709-
assert(Context);
1710-
assert(RetProgram);
1711-
assert(DeviceList && DeviceList[0] == Context->Device);
1712-
ze_device_handle_t ZeDevice = Context->Device->ZeDevice;
1725+
if (DeviceList[0] != Context->Device)
1726+
return PI_INVALID_DEVICE;
17131727

1714-
// Check the binary too.
1715-
assert(Lengths && Lengths[0] != 0);
1716-
assert(Binaries && Binaries[0] != nullptr);
17171728
size_t Length = Lengths[0];
1718-
auto Binary = pi_cast<const uint8_t *>(Binaries[0]);
1719-
1720-
ze_module_desc_t ZeModuleDesc = {};
1721-
ZeModuleDesc.version = ZE_MODULE_DESC_VERSION_CURRENT;
1722-
ZeModuleDesc.format = ZE_MODULE_FORMAT_NATIVE;
1723-
ZeModuleDesc.inputSize = Length;
1724-
ZeModuleDesc.pInputModule = Binary;
1725-
ZeModuleDesc.pBuildFlags = nullptr;
1726-
ZeModuleDesc.pConstants = nullptr;
1727-
1728-
ze_module_handle_t ZeModule;
1729-
ZE_CALL(zeModuleCreate(ZeDevice, &ZeModuleDesc, &ZeModule, 0));
1729+
auto Binary = Binaries[0];
1730+
1731+
// In OpenCL, clCreateProgramWithBinary() can be used to load any of the
1732+
// following: "program executable", "compiled program", or "library of
1733+
// compiled programs". In addition, the loaded program can be either
1734+
// IL (SPIR-v) or native device code. For now, we assume that
1735+
// piProgramCreateWithBinary() is only used to load a "program executable"
1736+
// as native device code.
1737+
//
1738+
// If we wanted to support all the same cases as OpenCL, we would need to
1739+
// somehow examine the binary image to distinguish the cases. Alternatively,
1740+
// we could change the PI interface and have the caller pass additional
1741+
// information to distinguish the cases.
17301742

17311743
try {
1732-
// TODO: It's not clear whether piProgramCreateWithBinary() can be
1733-
// used also to create programs with state Object. For now, assume
1734-
// it's always a fully linked executable.
1735-
*RetProgram = new _pi_program(Context, ZeModule, _pi_program::Exe);
1744+
*Program = new _pi_program(Context, Binary, Length, _pi_program::Native);
17361745
} catch (const std::bad_alloc &) {
17371746
return PI_OUT_OF_HOST_MEMORY;
17381747
} catch (...) {
17391748
return PI_ERROR_UNKNOWN;
17401749
}
1741-
1742-
if (BinaryStatus) {
1743-
*BinaryStatus = PI_SUCCESS;
1744-
}
17451750
return PI_SUCCESS;
17461751
}
17471752

@@ -1770,11 +1775,9 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName,
17701775
return ReturnValue(Program->Context->Device);
17711776
case PI_PROGRAM_INFO_BINARY_SIZES: {
17721777
size_t SzBinary;
1773-
if (Program->State == _pi_program::IL) {
1774-
// The OpenCL spec indicates that PI_PROGRAM_INFO_BINARY_SIZES is not
1775-
// defined in this case, but it does not say what to do if it happens.
1776-
// Returning a zero size seems reasonable.
1777-
SzBinary = 0;
1778+
if (Program->State == _pi_program::IL ||
1779+
Program->State == _pi_program::Native) {
1780+
SzBinary = Program->CodeLength;
17781781
} else {
17791782
assert(Program->State == _pi_program::Object ||
17801783
Program->State == _pi_program::Exe ||
@@ -1806,8 +1809,9 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName,
18061809
uint8_t **PBinary = pi_cast<uint8_t **>(ParamValue);
18071810
if (!PBinary[0])
18081811
break;
1809-
if (Program->State == _pi_program::IL) {
1810-
// Nothing to do (see comments above for PI_PROGRAM_INFO_BINARY_SIZES).
1812+
if (Program->State == _pi_program::IL ||
1813+
Program->State == _pi_program::Native) {
1814+
std::memcpy(PBinary[0], Program->Code.get(), Program->CodeLength);
18111815
} else {
18121816
assert(Program->State == _pi_program::Object ||
18131817
Program->State == _pi_program::Exe ||
@@ -1827,10 +1831,11 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName,
18271831
case PI_PROGRAM_INFO_NUM_KERNELS: {
18281832
uint32_t NumKernels;
18291833
if (Program->State == _pi_program::IL ||
1834+
Program->State == _pi_program::Native ||
18301835
Program->State == _pi_program::Object) {
1831-
// The OpenCL spec says this case isn't supported, but it doesn't say
1832-
// what to do if it happens. Returning zero seems reasonable.
1833-
NumKernels = 0;
1836+
// The OpenCL spec says to return CL_INVALID_PROGRAM_EXECUTABLE in this
1837+
// case, but there is no corresponding PI error code.
1838+
return PI_INVALID_OPERATION;
18341839
} else {
18351840
assert(Program->State == _pi_program::Exe ||
18361841
Program->State == _pi_program::LinkedExe);
@@ -1850,8 +1855,11 @@ pi_result piProgramGetInfo(pi_program Program, pi_program_info ParamName,
18501855
try {
18511856
std::string PINames{""};
18521857
if (Program->State == _pi_program::IL ||
1858+
Program->State == _pi_program::Native ||
18531859
Program->State == _pi_program::Object) {
1854-
// Nothing to do (see comment for PI_PROGRAM_INFO_NUM_KERNELS above).
1860+
// The OpenCL spec says to return CL_INVALID_PROGRAM_EXECUTABLE in this
1861+
// case, but there is no corresponding PI error code.
1862+
return PI_INVALID_OPERATION;
18551863
} else {
18561864
assert(Program->State == _pi_program::Exe ||
18571865
Program->State == _pi_program::LinkedExe);
@@ -1996,26 +2004,79 @@ pi_result piProgramCompile(
19962004
const pi_program *InputHeaders, const char **HeaderIncludeNames,
19972005
void (*PFnNotify)(pi_program Program, void *UserData), void *UserData) {
19982006

1999-
// We only support one device with Level Zero, and we don't support input
2000-
// headers.
2001-
assert(Program);
2002-
assert(NumDevices == 1);
2003-
assert(DeviceList);
2004-
assert(NumInputHeaders == 0);
2005-
assert(Program);
2006-
assert(!PFnNotify && !UserData);
2007+
// The OpenCL spec says this should return CL_INVALID_PROGRAM, but there is
2008+
// no corresponding PI error code.
2009+
if (!Program)
2010+
return PI_INVALID_OPERATION;
20072011

2008-
// It is only valid to compile a program if it was loaded from IL.
2012+
// It's only valid to compile a program created from IL (we don't support
2013+
// programs created from source code).
2014+
//
2015+
// The OpenCL spec says that the header parameters are ignored when compiling
2016+
// IL programs, so we don't validate them.
20092017
if (Program->State != _pi_program::IL)
20102018
return PI_INVALID_OPERATION;
2011-
assert(Program->ILBytes);
20122019

2013-
// Translate collected specialization constants.
2020+
// These aren't supported.
2021+
assert(!PFnNotify && !UserData);
2022+
2023+
pi_result res = compileOrBuild(Program, NumDevices, DeviceList, Options);
2024+
if (res != PI_SUCCESS)
2025+
return res;
2026+
2027+
Program->State = _pi_program::Object;
2028+
return PI_SUCCESS;
2029+
}
2030+
2031+
pi_result piProgramBuild(pi_program Program, pi_uint32 NumDevices,
2032+
const pi_device *DeviceList, const char *Options,
2033+
void (*PFnNotify)(pi_program Program, void *UserData),
2034+
void *UserData) {
2035+
2036+
// The OpenCL spec says this should return CL_INVALID_PROGRAM, but there is
2037+
// no corresponding PI error code.
2038+
if (!Program)
2039+
return PI_INVALID_OPERATION;
2040+
2041+
// It is legal to build a program created from either IL or from native
2042+
// device code.
2043+
if (Program->State != _pi_program::IL &&
2044+
Program->State != _pi_program::Native)
2045+
return PI_INVALID_OPERATION;
2046+
2047+
// These aren't supported.
2048+
assert(!PFnNotify && !UserData);
2049+
2050+
pi_result res = compileOrBuild(Program, NumDevices, DeviceList, Options);
2051+
if (res != PI_SUCCESS)
2052+
return res;
2053+
2054+
Program->State = _pi_program::Exe;
2055+
return PI_SUCCESS;
2056+
}
2057+
2058+
// Perform common operations for compiling or building a program.
2059+
static pi_result compileOrBuild(pi_program Program, pi_uint32 NumDevices,
2060+
const pi_device *DeviceList,
2061+
const char *Options) {
2062+
2063+
if ((NumDevices && !DeviceList) || (!NumDevices && DeviceList))
2064+
return PI_INVALID_VALUE;
2065+
2066+
// We only support one device with Level Zero.
2067+
assert(NumDevices == 1 && DeviceList);
2068+
2069+
// We should have either IL or native device code.
2070+
assert(Program->Code);
2071+
2072+
// Specialization constants are used only if the program was created from
2073+
// IL. Translate them to the Level Zero format.
20142074
ze_module_constants_t ZeSpecConstants = {};
2015-
std::vector<uint32_t> ZeSpecContantsIds(Program->ZeSpecConstants.size());
2016-
std::vector<uint64_t> ZeSpecContantsValues(Program->ZeSpecConstants.size());
2017-
{
2075+
if (Program->State == _pi_program::IL) {
20182076
std::lock_guard<std::mutex> Guard(Program->MutexZeSpecConstants);
2077+
2078+
std::vector<uint32_t> ZeSpecContantsIds(Program->ZeSpecConstants.size());
2079+
std::vector<uint64_t> ZeSpecContantsValues(Program->ZeSpecConstants.size());
20192080
ZeSpecConstants.numConstants = Program->ZeSpecConstants.size();
20202081
auto ZeSpecContantsIdsIt = ZeSpecContantsIds.begin();
20212082
auto ZeSpecContantsValuesIt = ZeSpecContantsValues.begin();
@@ -2029,12 +2090,14 @@ pi_result piProgramCompile(
20292090
ZeSpecConstants.pConstantValues = ZeSpecContantsValues.data();
20302091
}
20312092

2032-
// Ask Level Zero to build the IL and load the native code onto the device.
2093+
// Ask Level Zero to build and load the native code onto the device.
20332094
ze_module_desc_t ZeModuleDesc = {};
20342095
ZeModuleDesc.version = ZE_MODULE_DESC_VERSION_CURRENT;
2035-
ZeModuleDesc.format = ZE_MODULE_FORMAT_IL_SPIRV;
2036-
ZeModuleDesc.inputSize = Program->ILLength;
2037-
ZeModuleDesc.pInputModule = Program->ILBytes.get();
2096+
ZeModuleDesc.format = (Program->State == _pi_program::IL)
2097+
? ZE_MODULE_FORMAT_IL_SPIRV
2098+
: ZE_MODULE_FORMAT_NATIVE;
2099+
ZeModuleDesc.inputSize = Program->CodeLength;
2100+
ZeModuleDesc.pInputModule = Program->Code.get();
20382101
ZeModuleDesc.pBuildFlags = Options;
20392102
ZeModuleDesc.pConstants = &ZeSpecConstants;
20402103

@@ -2050,30 +2113,14 @@ pi_result piProgramCompile(
20502113
ZE_CALL(zeModuleGetProperties(ZeModule, &ZeModuleProps));
20512114
Program->HasImports = (ZeModuleProps.flags & ZE_MODULE_PROPERTY_FLAG_IMPORTS);
20522115

2053-
// The program is now in the Object state. We no longer need the IL.
2054-
Program->State = _pi_program::Object;
2055-
Program->ILBytes.reset();
2116+
// We no longer need the IL / native code.
2117+
// The caller must set the State to Object or Exe as appropriate.
2118+
Program->Code.reset();
20562119
Program->ZeModule = ZeModule;
20572120
Program->ZeBuildLog = ZeBuildLog;
20582121
return PI_SUCCESS;
20592122
}
20602123

2061-
pi_result piProgramBuild(pi_program Program, pi_uint32 NumDevices,
2062-
const pi_device *DeviceList, const char *Options,
2063-
void (*PFnNotify)(pi_program Program, void *UserData),
2064-
void *UserData) {
2065-
2066-
// On Level Zero, there's no real difference between compiling and building.
2067-
// We just assume that the resulting program is an executable in the case of
2068-
// building.
2069-
pi_result res = piProgramCompile(Program, NumDevices, DeviceList, Options, 0,
2070-
nullptr, nullptr, PFnNotify, UserData);
2071-
if (res != PI_SUCCESS)
2072-
return res;
2073-
Program->State = _pi_program::Exe;
2074-
return PI_SUCCESS;
2075-
}
2076-
20772124
pi_result piProgramGetBuildInfo(pi_program Program, pi_device Device,
20782125
cl_program_build_info ParamName,
20792126
size_t ParamValueSize, void *ParamValue,

sycl/plugins/level_zero/pi_level_zero.hpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,11 @@ struct _pi_program : _pi_object {
316316
// is not yet compiled.
317317
IL,
318318

319+
// The program has been created by loading native code, but it has not yet
320+
// been built. This is equivalent to an OpenCL "program executable" that
321+
// is loaded via clCreateProgramWithBinary().
322+
Native,
323+
319324
// The program consists of native code (typically compiled from SPIR-v),
320325
// but it has unresolved external dependencies which need to be resolved
321326
// by linking with other Object state program(s). Programs in this state
@@ -392,7 +397,7 @@ struct _pi_program : _pi_object {
392397
NumMods = Prog->LinkedPrograms.size();
393398
IsDone = (It == Prog->LinkedPrograms.end());
394399
Mod = IsDone ? nullptr : (*It)->ZeModule;
395-
} else if (Prog->State == IL) {
400+
} else if (Prog->State == IL || Prog->State == Native) {
396401
NumMods = 0;
397402
IsDone = true;
398403
Mod = nullptr;
@@ -425,13 +430,13 @@ struct _pi_program : _pi_object {
425430
std::vector<LinkedReleaser>::iterator It;
426431
};
427432

428-
// Construct a program in IL state.
429-
_pi_program(pi_context Context, const void *InputIL, size_t InputILLength)
430-
: State(IL), Context(Context), ILBytes(new uint8_t[InputILLength]),
431-
ILLength(InputILLength), ZeModule(nullptr), HasImports(false),
433+
// Construct a program in IL or Native state.
434+
_pi_program(pi_context Context, const void *Input, size_t Length, state St)
435+
: State(St), Context(Context), Code(new uint8_t[Length]),
436+
CodeLength(Length), ZeModule(nullptr), HasImports(false),
432437
HasImportsAndIsLinked(false), ZeBuildLog(nullptr) {
433438

434-
std::memcpy(ILBytes.get(), InputIL, InputILLength);
439+
std::memcpy(Code.get(), Input, Length);
435440
}
436441

437442
// Construct a program in either Object or Exe state.
@@ -453,12 +458,11 @@ struct _pi_program : _pi_object {
453458
state State;
454459
pi_context Context; // Context of the program.
455460

456-
// Used for programs in IL state.
457-
std::unique_ptr<uint8_t[]> ILBytes; // Array containing raw IL.
458-
size_t ILLength; // Size (bytes) of the array.
461+
// Used for programs in IL or Native states.
462+
std::unique_ptr<uint8_t[]> Code; // Array containing raw IL / native code.
463+
size_t CodeLength; // Size (bytes) of the array.
459464

460465
// Level Zero specialization constants, used for programs in IL state.
461-
// Access to this member is protected by Mutex.
462466
std::unordered_map<uint32_t, uint64_t> ZeSpecConstants;
463467
std::mutex MutexZeSpecConstants; // Protects access to this field.
464468

0 commit comments

Comments
 (0)