Skip to content

Commit 77e110d

Browse files
authored
[SYCL][RTC] Adopt recent changes from sycl-post-link (#17447)
Incorporates recent changes to `sycl-post-link` into the RTC-specific version: - #16729 - #16236 - #17211 --------- Signed-off-by: Julian Oppermann <[email protected]>
1 parent 76c6653 commit 77e110d

File tree

3 files changed

+140
-43
lines changed

3 files changed

+140
-43
lines changed

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 92 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,19 @@ static bool getDeviceLibraries(const ArgList &Args,
531531
return FoundUnknownLib;
532532
}
533533

534+
static Expected<std::unique_ptr<llvm::Module>>
535+
loadBitcodeLibrary(StringRef LibPath, LLVMContext &Context) {
536+
SMDiagnostic Diag;
537+
std::unique_ptr<llvm::Module> Lib = parseIRFile(LibPath, Diag, Context);
538+
if (!Lib) {
539+
std::string DiagMsg;
540+
raw_string_ostream SOS(DiagMsg);
541+
Diag.print(/*ProgName=*/nullptr, SOS);
542+
return createStringError(DiagMsg);
543+
}
544+
return std::move(Lib);
545+
}
546+
534547
Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
535548
const InputArgList &UserArgList,
536549
std::string &BuildLog) {
@@ -558,16 +571,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
558571
for (const std::string &LibName : LibNames) {
559572
std::string LibPath = DPCPPRoot + "/lib/" + LibName;
560573

561-
SMDiagnostic Diag;
562-
std::unique_ptr<llvm::Module> Lib = parseIRFile(LibPath, Diag, Context);
563-
if (!Lib) {
564-
std::string DiagMsg;
565-
raw_string_ostream SOS(DiagMsg);
566-
Diag.print(/*ProgName=*/nullptr, SOS);
567-
return createStringError(DiagMsg);
574+
auto LibOrErr = loadBitcodeLibrary(LibPath, Context);
575+
if (!LibOrErr) {
576+
return LibOrErr.takeError();
568577
}
569578

570-
if (Linker::linkModules(Module, std::move(Lib), Linker::LinkOnlyNeeded)) {
579+
if (Linker::linkModules(Module, std::move(*LibOrErr),
580+
Linker::LinkOnlyNeeded)) {
571581
return createStringError("Unable to link device library %s: %s",
572582
LibPath.c_str(), BuildLog.c_str());
573583
}
@@ -607,6 +617,31 @@ static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) {
607617
return SPLIT_AUTO;
608618
}
609619

620+
static void encodeProperties(PropertySetRegistry &Properties,
621+
RTCDevImgInfo &DevImgInfo) {
622+
const auto &PropertySets = Properties.getPropSets();
623+
624+
DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
625+
for (auto [KV, FrozenPropSet] :
626+
zip_equal(PropertySets, DevImgInfo.Properties)) {
627+
const auto &PropertySetName = KV.first;
628+
const auto &PropertySet = KV.second;
629+
FrozenPropSet =
630+
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
631+
for (auto [KV2, FrozenProp] :
632+
zip_equal(PropertySet, FrozenPropSet.Values)) {
633+
const auto &PropertyName = KV2.first;
634+
const auto &PropertyValue = KV2.second;
635+
FrozenProp = PropertyValue.getType() == PropertyValue::Type::UINT32
636+
? FrozenPropertyValue{PropertyName.str(),
637+
PropertyValue.asUint32()}
638+
: FrozenPropertyValue{
639+
PropertyName.str(), PropertyValue.asRawByteArray(),
640+
PropertyValue.getRawByteArraySize()};
641+
}
642+
};
643+
}
644+
610645
Expected<PostLinkResult>
611646
jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
612647
const InputArgList &UserArgList) {
@@ -637,9 +672,9 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
637672
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
638673
// `removeDeviceGlobalFromCompilerUsed` methods.
639674

640-
assert(!isModuleUsingAsan(*Module));
641-
// Otherwise: Need to instrument each image scope device globals if the module
642-
// has been instrumented by sanitizer pass.
675+
assert(!(isModuleUsingAsan(*Module) || isModuleUsingMsan(*Module) ||
676+
isModuleUsingTsan(*Module)));
677+
// Otherwise: Run `SanitizerKernelMetadataPass`.
643678

644679
// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
645680
// LLVM IR specification.
@@ -668,6 +703,7 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
668703
// `-fno-sycl-device-code-split-esimd` as a prerequisite for compiling
669704
// `invoke_simd` code.
670705

706+
bool IsBF16DeviceLibUsed = false;
671707
while (Splitter->hasMoreSplits()) {
672708
ModuleDesc MDesc = Splitter->nextSplit();
673709

@@ -701,35 +737,58 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
701737
/*DeviceGlobals=*/false};
702738
PropertySetRegistry Properties =
703739
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
740+
741+
// When the split mode is none, the required work group size will be added
742+
// to the whole module, which will make the runtime unable to launch the
743+
// other kernels in the module that have different required work group
744+
// sizes or no required work group sizes. So we need to remove the
745+
// required work group size metadata in this case.
746+
if (SplitMode == module_split::SPLIT_NONE) {
747+
Properties.remove(PropSetRegTy::SYCL_DEVICE_REQUIREMENTS,
748+
PropSetRegTy::PROPERTY_REQD_WORK_GROUP_SIZE);
749+
}
750+
704751
// TODO: Manually add `compile_target` property as in
705752
// `saveModuleProperties`?
706-
const auto &PropertySets = Properties.getPropSets();
707-
708-
DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
709-
for (auto [KV, FrozenPropSet] :
710-
zip_equal(PropertySets, DevImgInfo.Properties)) {
711-
const auto &PropertySetName = KV.first;
712-
const auto &PropertySet = KV.second;
713-
FrozenPropSet =
714-
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
715-
for (auto [KV2, FrozenProp] :
716-
zip_equal(PropertySet, FrozenPropSet.Values)) {
717-
const auto &PropertyName = KV2.first;
718-
const auto &PropertyValue = KV2.second;
719-
FrozenProp =
720-
PropertyValue.getType() == PropertyValue::Type::UINT32
721-
? FrozenPropertyValue{PropertyName.str(),
722-
PropertyValue.asUint32()}
723-
: FrozenPropertyValue{PropertyName.str(),
724-
PropertyValue.asRawByteArray(),
725-
PropertyValue.getRawByteArraySize()};
726-
}
727-
};
728753

754+
encodeProperties(Properties, DevImgInfo);
755+
756+
IsBF16DeviceLibUsed |= isSYCLDeviceLibBF16Used(MDesc.getModule());
729757
Modules.push_back(MDesc.releaseModulePtr());
730758
}
731759
}
732760

761+
if (IsBF16DeviceLibUsed) {
762+
const std::string &DPCPPRoot = getDPCPPRoot();
763+
if (DPCPPRoot == InvalidDPCPPRoot) {
764+
return createStringError("Could not locate DPCPP root directory");
765+
}
766+
767+
auto &Ctx = Modules.front()->getContext();
768+
auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error {
769+
std::string LibPath = DPCPPRoot + "/lib/" + LibName;
770+
auto LibOrErr = loadBitcodeLibrary(LibPath, Ctx);
771+
if (!LibOrErr) {
772+
return LibOrErr.takeError();
773+
}
774+
775+
std::unique_ptr<llvm::Module> LibModule = std::move(*LibOrErr);
776+
PropertySetRegistry Properties =
777+
computeDeviceLibProperties(*LibModule, LibName);
778+
encodeProperties(Properties, DevImgInfoVec.emplace_back());
779+
Modules.push_back(std::move(LibModule));
780+
781+
return Error::success();
782+
};
783+
784+
if (auto Err = WrapLibraryInDevImg("libsycl-fallback-bfloat16.bc")) {
785+
return std::move(Err);
786+
}
787+
if (auto Err = WrapLibraryInDevImg("libsycl-native-bfloat16.bc")) {
788+
return std::move(Err);
789+
}
790+
}
791+
733792
assert(DevImgInfoVec.size() == Modules.size());
734793
RTCBundleInfo BundleInfo;
735794
BundleInfo.DevImgInfos = DynArray<RTCDevImgInfo>{DevImgInfoVec.size()};

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,13 +1837,17 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18371837
return {};
18381838
}
18391839

1840-
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
1840+
static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
18411841
// For bfloat16 device library image, we should keep it. However, in some
18421842
// scenario, __sycl_register_lib can be called multiple times and the same
18431843
// bfloat16 device library image may be handled multiple times which is not
18441844
// needed. 2 static bool variables are created to record whether native or
18451845
// fallback bfloat16 device library image has been handled, if yes, we just
18461846
// need to skip it.
1847+
// We cannot prevent redundant loads of device library images if they are part
1848+
// of a runtime-compiled device binary, as these will be freed when the
1849+
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1850+
// on the presence of RTC device library images.
18471851
sycl_device_binary_property_set ImgPS;
18481852
static bool IsNativeBF16DeviceLibHandled = false;
18491853
static bool IsFallbackBF16DeviceLibHandled = false;
@@ -1861,8 +1865,13 @@ static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
18611865
if (ImgP == ImgPS->PropertiesEnd)
18621866
return true;
18631867

1864-
// A valid bfloat16 device library image is found here, need to check
1865-
// wheter it has been handled already.
1868+
// A valid bfloat16 device library image is found here.
1869+
// If it originated from RTC, we cannot skip it, but do not mark it as
1870+
// being present.
1871+
if (IsRTC)
1872+
return false;
1873+
1874+
// Otherwise, we need to check whether it has been handled already.
18661875
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
18671876
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
18681877
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
@@ -1879,14 +1888,33 @@ static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
18791888
return true;
18801889
}
18811890

1891+
static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
1892+
// Check whether the first device binary contains a legacy format offload
1893+
// entry with a `$` in its name.
1894+
if (DeviceBinary->NumDeviceBinaries > 0) {
1895+
sycl_device_binary Binary = DeviceBinary->DeviceBinaries;
1896+
if (Binary->EntriesBegin != Binary->EntriesEnd) {
1897+
sycl_offload_entry Entry = Binary->EntriesBegin;
1898+
if (!Entry->IsNewOffloadEntryType() &&
1899+
std::string_view{Entry->name}.find('$') != std::string_view::npos) {
1900+
return true;
1901+
}
1902+
}
1903+
}
1904+
return false;
1905+
}
1906+
18821907
void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
18831908
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
1909+
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
18841910
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
18851911
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
18861912
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
18871913
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
1888-
// Treat the image as empty one
1889-
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
1914+
// If the image does not contain kernels, skip it unless it is one of the
1915+
// bfloat16 device libraries, and it wasn't loaded before or resulted from
1916+
// runtime compilation.
1917+
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
18901918
continue;
18911919

18921920
std::unique_ptr<RTDeviceBinaryImage> Img;
@@ -2081,15 +2109,19 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
20812109
}
20822110

20832111
void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
2112+
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
20842113
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
20852114
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
20862115
auto DevImgIt = m_DeviceImages.find(RawImg);
20872116
if (DevImgIt == m_DeviceImages.end())
20882117
continue;
20892118
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
20902119
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
2091-
// Treat the image as empty one
2092-
if (EntriesB == EntriesE)
2120+
// Skip clean up if there are no offload entries, unless `DeviceBinary`
2121+
// resulted from runtime compilation: Then, this is one of the `bfloat16`
2122+
// device libraries, so we want to make sure that the image and its exported
2123+
// symbols are removed from the program manager's maps.
2124+
if (EntriesB == EntriesE && !IsRTC)
20932125
continue;
20942126

20952127
RTDeviceBinaryImage *Img = DevImgIt->second.get();

sycl/test-e2e/KernelCompiler/sycl.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ void device_libs_kernel(float *ptr) {
134134
135135
// cl_intel_devicelib_imf
136136
ptr[3] = sycl::ext::intel::math::sqrt(ptr[3] * 2);
137+
138+
// cl_intel_devicelib_imf_bf16
139+
ptr[4] = sycl::ext::intel::math::float2bfloat16(ptr[4] * 0.5f);
140+
141+
// cl_intel_devicelib_bfloat16
142+
ptr[5] = sycl::ext::oneapi::bfloat16{ptr[5] / 0.25f};
137143
}
138144
)===";
139145

@@ -435,7 +441,7 @@ int test_device_libraries() {
435441
exe_kb kbExe = syclex::build(kbSrc);
436442

437443
sycl::kernel k = kbExe.ext_oneapi_get_kernel("device_libs_kernel");
438-
constexpr size_t nElem = 4;
444+
constexpr size_t nElem = 6;
439445
float *ptr = sycl::malloc_shared<float>(nElem, q);
440446
for (int i = 0; i < nElem; ++i)
441447
ptr[i] = 1.0f;
@@ -446,8 +452,8 @@ int test_device_libraries() {
446452
});
447453
q.wait_and_throw();
448454

449-
// Check that the kernel was executed. Given the {1.0, 1.0, 1.0, 1.0} input,
450-
// the expected result is approximately {0.84, 1.41, 0.0, 1.41}.
455+
// Check that the kernel was executed. Given the {1.0, ..., 1.0} input,
456+
// the expected result is approximately {0.84, 1.41, 0.0, 1.41, 0.5, 4.0}.
451457
for (unsigned i = 0; i < nElem; ++i) {
452458
std::cout << ptr[i] << ' ';
453459
assert(ptr[i] != 1.0f);

0 commit comments

Comments
 (0)