@@ -531,6 +531,19 @@ static bool getDeviceLibraries(const ArgList &Args,
531
531
return FoundUnknownLib;
532
532
}
533
533
534
+ static Expected<std::unique_ptr<llvm::Module>>
535
+ loadBitcodeLibrary (StringRef LibPath, LLVMContext &Context) {
536
+ SMDiagnostic Diag;
537
+ std::unique_ptr<llvm::Module> Lib = parseIRFile (LibPath, Diag, Context);
538
+ if (!Lib) {
539
+ std::string DiagMsg;
540
+ raw_string_ostream SOS (DiagMsg);
541
+ Diag.print (/* ProgName=*/ nullptr , SOS);
542
+ return createStringError (DiagMsg);
543
+ }
544
+ return std::move (Lib);
545
+ }
546
+
534
547
Error jit_compiler::linkDeviceLibraries (llvm::Module &Module,
535
548
const InputArgList &UserArgList,
536
549
std::string &BuildLog) {
@@ -558,16 +571,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
558
571
for (const std::string &LibName : LibNames) {
559
572
std::string LibPath = DPCPPRoot + " /lib/" + LibName;
560
573
561
- SMDiagnostic Diag;
562
- std::unique_ptr<llvm::Module> Lib = parseIRFile (LibPath, Diag, Context);
563
- if (!Lib) {
564
- std::string DiagMsg;
565
- raw_string_ostream SOS (DiagMsg);
566
- Diag.print (/* ProgName=*/ nullptr , SOS);
567
- return createStringError (DiagMsg);
574
+ auto LibOrErr = loadBitcodeLibrary (LibPath, Context);
575
+ if (!LibOrErr) {
576
+ return LibOrErr.takeError ();
568
577
}
569
578
570
- if (Linker::linkModules (Module, std::move (Lib), Linker::LinkOnlyNeeded)) {
579
+ if (Linker::linkModules (Module, std::move (*LibOrErr),
580
+ Linker::LinkOnlyNeeded)) {
571
581
return createStringError (" Unable to link device library %s: %s" ,
572
582
LibPath.c_str (), BuildLog.c_str ());
573
583
}
@@ -607,6 +617,31 @@ static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) {
607
617
return SPLIT_AUTO;
608
618
}
609
619
620
+ static void encodeProperties (PropertySetRegistry &Properties,
621
+ RTCDevImgInfo &DevImgInfo) {
622
+ const auto &PropertySets = Properties.getPropSets ();
623
+
624
+ DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size ()};
625
+ for (auto [KV, FrozenPropSet] :
626
+ zip_equal (PropertySets, DevImgInfo.Properties )) {
627
+ const auto &PropertySetName = KV.first ;
628
+ const auto &PropertySet = KV.second ;
629
+ FrozenPropSet =
630
+ FrozenPropertySet{PropertySetName.str (), PropertySet.size ()};
631
+ for (auto [KV2, FrozenProp] :
632
+ zip_equal (PropertySet, FrozenPropSet.Values )) {
633
+ const auto &PropertyName = KV2.first ;
634
+ const auto &PropertyValue = KV2.second ;
635
+ FrozenProp = PropertyValue.getType () == PropertyValue::Type::UINT32
636
+ ? FrozenPropertyValue{PropertyName.str (),
637
+ PropertyValue.asUint32 ()}
638
+ : FrozenPropertyValue{
639
+ PropertyName.str (), PropertyValue.asRawByteArray (),
640
+ PropertyValue.getRawByteArraySize ()};
641
+ }
642
+ };
643
+ }
644
+
610
645
Expected<PostLinkResult>
611
646
jit_compiler::performPostLink (std::unique_ptr<llvm::Module> Module,
612
647
const InputArgList &UserArgList) {
@@ -637,9 +672,9 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
637
672
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
638
673
// `removeDeviceGlobalFromCompilerUsed` methods.
639
674
640
- assert (!isModuleUsingAsan (*Module));
641
- // Otherwise: Need to instrument each image scope device globals if the module
642
- // has been instrumented by sanitizer pass .
675
+ assert (!( isModuleUsingAsan (*Module) || isModuleUsingMsan (*Module) ||
676
+ isModuleUsingTsan (*Module)));
677
+ // Otherwise: Run `SanitizerKernelMetadataPass` .
643
678
644
679
// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
645
680
// LLVM IR specification.
@@ -668,6 +703,7 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
668
703
// `-fno-sycl-device-code-split-esimd` as a prerequisite for compiling
669
704
// `invoke_simd` code.
670
705
706
+ bool IsBF16DeviceLibUsed = false ;
671
707
while (Splitter->hasMoreSplits ()) {
672
708
ModuleDesc MDesc = Splitter->nextSplit ();
673
709
@@ -701,35 +737,58 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
701
737
/* DeviceGlobals=*/ false };
702
738
PropertySetRegistry Properties =
703
739
computeModuleProperties (MDesc.getModule (), MDesc.entries (), PropReq);
740
+
741
+ // When the split mode is none, the required work group size will be added
742
+ // to the whole module, which will make the runtime unable to launch the
743
+ // other kernels in the module that have different required work group
744
+ // sizes or no required work group sizes. So we need to remove the
745
+ // required work group size metadata in this case.
746
+ if (SplitMode == module_split::SPLIT_NONE) {
747
+ Properties.remove (PropSetRegTy::SYCL_DEVICE_REQUIREMENTS,
748
+ PropSetRegTy::PROPERTY_REQD_WORK_GROUP_SIZE);
749
+ }
750
+
704
751
// TODO: Manually add `compile_target` property as in
705
752
// `saveModuleProperties`?
706
- const auto &PropertySets = Properties.getPropSets ();
707
-
708
- DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size ()};
709
- for (auto [KV, FrozenPropSet] :
710
- zip_equal (PropertySets, DevImgInfo.Properties )) {
711
- const auto &PropertySetName = KV.first ;
712
- const auto &PropertySet = KV.second ;
713
- FrozenPropSet =
714
- FrozenPropertySet{PropertySetName.str (), PropertySet.size ()};
715
- for (auto [KV2, FrozenProp] :
716
- zip_equal (PropertySet, FrozenPropSet.Values )) {
717
- const auto &PropertyName = KV2.first ;
718
- const auto &PropertyValue = KV2.second ;
719
- FrozenProp =
720
- PropertyValue.getType () == PropertyValue::Type::UINT32
721
- ? FrozenPropertyValue{PropertyName.str (),
722
- PropertyValue.asUint32 ()}
723
- : FrozenPropertyValue{PropertyName.str (),
724
- PropertyValue.asRawByteArray (),
725
- PropertyValue.getRawByteArraySize ()};
726
- }
727
- };
728
753
754
+ encodeProperties (Properties, DevImgInfo);
755
+
756
+ IsBF16DeviceLibUsed |= isSYCLDeviceLibBF16Used (MDesc.getModule ());
729
757
Modules.push_back (MDesc.releaseModulePtr ());
730
758
}
731
759
}
732
760
761
+ if (IsBF16DeviceLibUsed) {
762
+ const std::string &DPCPPRoot = getDPCPPRoot ();
763
+ if (DPCPPRoot == InvalidDPCPPRoot) {
764
+ return createStringError (" Could not locate DPCPP root directory" );
765
+ }
766
+
767
+ auto &Ctx = Modules.front ()->getContext ();
768
+ auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error {
769
+ std::string LibPath = DPCPPRoot + " /lib/" + LibName;
770
+ auto LibOrErr = loadBitcodeLibrary (LibPath, Ctx);
771
+ if (!LibOrErr) {
772
+ return LibOrErr.takeError ();
773
+ }
774
+
775
+ std::unique_ptr<llvm::Module> LibModule = std::move (*LibOrErr);
776
+ PropertySetRegistry Properties =
777
+ computeDeviceLibProperties (*LibModule, LibName);
778
+ encodeProperties (Properties, DevImgInfoVec.emplace_back ());
779
+ Modules.push_back (std::move (LibModule));
780
+
781
+ return Error::success ();
782
+ };
783
+
784
+ if (auto Err = WrapLibraryInDevImg (" libsycl-fallback-bfloat16.bc" )) {
785
+ return std::move (Err);
786
+ }
787
+ if (auto Err = WrapLibraryInDevImg (" libsycl-native-bfloat16.bc" )) {
788
+ return std::move (Err);
789
+ }
790
+ }
791
+
733
792
assert (DevImgInfoVec.size () == Modules.size ());
734
793
RTCBundleInfo BundleInfo;
735
794
BundleInfo.DevImgInfos = DynArray<RTCDevImgInfo>{DevImgInfoVec.size ()};
0 commit comments