44
44
#include " llvm/Transforms/Utils/Cloning.h"
45
45
46
46
#include < algorithm>
47
+ #include < map>
47
48
#include < memory>
48
49
#include < string>
49
50
#include < vector>
50
51
51
52
using namespace llvm ;
52
53
53
54
// Shorthand for a list of strings.
using string_vector = std::vector<std::string>;
55
// Shorthand for a list of pointers to const functions.
+ using FuncPtrVector = std::vector<const Function *>;
56
// Maps a key (a StringRef) to the group of entry points associated with it.
+ using EntryPointsSet = std::map<StringRef, FuncPtrVector>;
57
// Shorthand for a uniquely-owned Module.
+ using ModuleUPtr = std::unique_ptr<Module>;
54
58
// Shorthand for the property set registry type.
using PropSetRegTy = llvm::util::PropertySetRegistry;
59
// Shorthand for a list of StringRef values.
+ using StringRefVector = std::vector<StringRef>;
55
60
56
61
namespace {
57
62
@@ -184,7 +189,6 @@ cl::opt<bool> EmitOnlyKernelsAsEntryPoints{
184
189
cl::cat (PostLinkCat), cl::init (false )};
185
190
186
191
struct ImagePropSaveInfo {
187
- bool SetSpecConstAtRT;
188
192
bool SpecConstsMet;
189
193
bool EmitKernelParamInfo;
190
194
bool EmitProgramMetadata;
@@ -314,10 +318,9 @@ bool isEntryPoint(const Function &F) {
314
318
// ResKernelModuleMap which maps some key to a group of entry points. Each such
315
319
// group along with IR it depends on (globals, functions from its call graph,
316
320
// ...) will constitute a separate module.
317
- void collectEntryPointToModuleMap (
318
- const Module &M,
319
- std::map<StringRef, std::vector<const Function *>> &ResKernelModuleMap,
320
- KernelMapEntryScope EntryScope) {
321
+ void collectEntryPointToModuleMap (const Module &M,
322
+ EntryPointsSet &ResKernelModuleMap,
323
+ KernelMapEntryScope EntryScope) {
321
324
322
325
// Only process module entry points:
323
326
for (const auto &F : M.functions ()) {
@@ -358,11 +361,11 @@ HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) {
358
361
// true - if there is an assertion in underlying functions,
359
362
// false - if there are definitely no assertions in underlying functions.
360
363
static std::map<const Function *, bool > hasAssertionInCallGraphMap;
361
- std::vector< const Function *> FuncCallStack;
364
+ FuncPtrVector FuncCallStack;
362
365
363
- static std::vector< const Function *> isIndirectlyCalledInGraph;
366
+ static FuncPtrVector isIndirectlyCalledInGraph;
364
367
365
- std::vector< const Function *> Workstack;
368
+ FuncPtrVector Workstack;
366
369
Workstack.push_back (Func);
367
370
368
371
while (!Workstack.empty ()) {
@@ -437,11 +440,11 @@ HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) {
437
440
return No_Assert;
438
441
}
439
442
440
- std::vector<StringRef> getKernelNamesUsingAssert (const Module &M) {
441
- std::vector<StringRef> Result;
443
+ StringRefVector getKernelNamesUsingAssert (const Module &M) {
444
+ StringRefVector Result;
442
445
443
446
bool HasIndirectlyCalledAssert = false ;
444
- std::vector< const Function *> Kernels;
447
+ FuncPtrVector Kernels;
445
448
for (const auto &F : M.functions ()) {
446
449
// TODO: handle SYCL_EXTERNAL functions for dynamic linkage.
447
450
// TODO: handle function pointers.
@@ -489,14 +492,13 @@ std::vector<uint32_t> getKernelReqdWorkGroupSizeMetadata(const Function &Func) {
489
492
}
490
493
491
494
// Input parameter KernelModuleMap is a map containing groups of entry points
492
- // with same values of the sycl-module-id attribute. ResSymbolsLists is a vector
495
+ // with same values of the sycl-module-id attribute. Return value is a vector
493
496
// of entry point name lists. Each vector element is a string with entry point
494
497
// names from the same module separated by \n.
495
498
// The function saves names of entry points from one group to a single
496
499
// std::string and stores this string to the ResSymbolsLists vector.
497
- void collectSymbolsLists (
498
- const std::map<StringRef, std::vector<const Function *>> &KernelModuleMap,
499
- string_vector &ResSymbolsLists) {
500
+ string_vector collectSymbolsLists (const EntryPointsSet &KernelModuleMap) {
501
+ string_vector ResSymbolsLists{};
500
502
for (const auto &It : KernelModuleMap) {
501
503
std::string SymbolsList;
502
504
for (const auto &F : It.second ) {
@@ -505,11 +507,12 @@ void collectSymbolsLists(
505
507
}
506
508
ResSymbolsLists.push_back (std::move (SymbolsList));
507
509
}
510
+ return ResSymbolsLists;
508
511
}
509
512
510
513
// Pairs a kernel module name with the module produced for it.
struct ResultModule {
511
514
// Name of the kernel module (the key of the entry point group).
StringRef KernelModuleName;
512
- std::unique_ptr<Module> ModulePtr;
515
// Uniquely-owned module holding the produced IR.
+ ModuleUPtr ModulePtr;
513
516
};
514
517
515
518
// Input parameter KernelModuleMap is a map containing groups of entry points
@@ -518,14 +521,14 @@ struct ResultModule {
518
521
// ResModules is a vector of pairs of kernel module names and produced modules.
519
522
// The function splits input LLVM IR module M into smaller ones and stores them
520
523
// to the ResModules vector.
521
- void splitModule (
522
- const Module &M,
523
- const std::map<StringRef, std:: vector<const Function *>> &KernelModuleMap,
524
- std::vector<ResultModule> &ResModules) {
524
+ std::vector<ResultModule> splitModule (const Module &M,
525
+ const EntryPointsSet &KernelModuleMap) {
526
+ std::vector<ResultModule> ResModules{};
527
+
525
528
for (const auto &It : KernelModuleMap) {
526
529
// For each group of entry points collect all dependencies.
527
530
SetVector<const GlobalValue *> GVs;
528
- std::vector< const Function *> Workqueue;
531
+ FuncPtrVector Workqueue;
529
532
530
533
for (const auto &F : It.second ) {
531
534
GVs.insert (F);
@@ -556,7 +559,7 @@ void splitModule(
556
559
ValueToValueMapTy VMap;
557
560
// Clone definitions only for needed globals. Others will be added as
558
561
// declarations and removed later.
559
- std::unique_ptr<Module> MClone = CloneModule (
562
+ ModuleUPtr MClone = CloneModule (
560
563
M, VMap, [&](const GlobalValue *GV) { return GVs.count (GV); });
561
564
562
565
// TODO: Use the new PassManager instead?
@@ -570,6 +573,8 @@ void splitModule(
570
573
// Save results.
571
574
ResModules.push_back ({It.first , std::move (MClone)});
572
575
}
576
+
577
+ return ResModules;
573
578
}
574
579
575
580
std::string makeResultFileName (Twine Ext, int I, StringRef Suffix) {
@@ -614,10 +619,10 @@ string_vector saveResultModules(const std::vector<ResultModule> &ResModules,
614
619
return Res;
615
620
}
616
621
617
- string_vector saveDeviceImageProperty (
618
- const std::vector<ResultModule> &ResultModules,
619
- const std::map<StringRef, std::vector< const Function *>> &KernelModuleMap,
620
- const ImagePropSaveInfo &ImgPSInfo) {
622
+ string_vector
623
+ saveDeviceImageProperty ( const std::vector<ResultModule> &ResultModules,
624
+ const EntryPointsSet &KernelModuleMap,
625
+ const ImagePropSaveInfo &ImgPSInfo) {
621
626
string_vector Res;
622
627
legacy::PassManager GetSYCLDeviceLibReqMask;
623
628
auto *SDLReqMaskLegacyPass = new SYCLDeviceLibReqMaskPass ();
@@ -711,7 +716,7 @@ string_vector saveDeviceImageProperty(
711
716
PropSet[PropSetRegTy::SYCL_MISC_PROP].insert ({" isEsimdImage" , true });
712
717
713
718
{
714
- std::vector<StringRef> FuncNames = getKernelNamesUsingAssert (M);
719
+ StringRefVector FuncNames = getKernelNamesUsingAssert (M);
715
720
for (const StringRef &FName : FuncNames)
716
721
PropSet[PropSetRegTy::SYCL_ASSERT_USED].insert ({FName, true });
717
722
}
@@ -781,8 +786,7 @@ void LowerEsimdConstructs(Module &M) {
781
786
782
787
using TableFiles = std::map<StringRef, string_vector>;
783
788
784
- TableFiles processOneModule (std::unique_ptr<Module> M, bool IsEsimd,
785
- bool SyclAndEsimdCode) {
789
+ TableFiles processOneModule (ModuleUPtr M, bool IsEsimd, bool SyclAndEsimdCode) {
786
790
TableFiles TblFiles;
787
791
if (!M)
788
792
return TblFiles;
@@ -804,7 +808,7 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
804
808
if (IsEsimd && LowerEsimd)
805
809
LowerEsimdConstructs (*M);
806
810
807
- std::map<StringRef, std::vector< const Function *>> GlobalsSet;
811
+ EntryPointsSet GlobalsSet;
808
812
809
813
bool DoSplit = SplitMode.getNumOccurrences () > 0 ;
810
814
@@ -813,19 +817,21 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
813
817
collectEntryPointToModuleMap (*M, GlobalsSet, Scope);
814
818
}
815
819
816
- std::vector<ResultModule> ResultModules ;
820
+ StringRef FileSuffix = IsEsimd ? " esimd_ " : " " ;
817
821
818
- bool DoSpecConst = SpecConstLower.getNumOccurrences () > 0 ;
819
- bool SpecConstsMet = false ;
820
- bool SetSpecConstAtRT = DoSpecConst && (SpecConstLower == SC_USE_RT_VAL);
822
+ std::vector<ResultModule> ResultModules;
821
823
822
824
if (DoSplit)
823
- splitModule (*M, GlobalsSet, ResultModules );
825
+ ResultModules = splitModule (*M, GlobalsSet);
824
826
// post-link always produces a code result, even if it is unmodified input
825
827
if (ResultModules.empty ())
826
828
ResultModules.push_back ({GLOBAL_SCOPE_NAME, std::move (M)});
827
829
830
+ bool DoSpecConst = SpecConstLower.getNumOccurrences () > 0 ;
831
+ bool SpecConstsMet = false ;
832
+
828
833
if (DoSpecConst) {
834
+ bool SetSpecConstAtRT = (SpecConstLower == SC_USE_RT_VAL);
829
835
ModulePassManager RunSpecConst;
830
836
ModuleAnalysisManager MAM;
831
837
SpecConstantsPass SCP (SetSpecConstAtRT);
@@ -851,23 +857,22 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
851
857
// Reuse input module with only regular SYCL kernels if there were
852
858
// no spec constants and no splitting.
853
859
// We cannot reuse input module for ESIMD code since it was transformed.
854
- bool CanReuseInputModule = !SpecConstsMet && (ResultModules.size () == 1 ) &&
855
- !SyclAndEsimdCode && !IsEsimd &&
856
- !IsLLVMUsedRemoved;
857
- string_vector Files =
858
- CanReuseInputModule
859
- ? string_vector{InputFilename}
860
- : saveResultModules (ResultModules, IsEsimd ? " esimd_" : " " );
860
+ bool CanReuseInputModule = !SyclAndEsimdCode && !IsEsimd &&
861
+ !IsLLVMUsedRemoved && !SpecConstsMet &&
862
+ (ResultModules.size () == 1 );
863
+ string_vector Files = CanReuseInputModule
864
+ ? string_vector{InputFilename}
865
+ : saveResultModules (ResultModules, FileSuffix);
861
866
862
867
// "Code" column is always output
863
868
std::copy (Files.begin (), Files.end (),
864
869
std::back_inserter (TblFiles[COL_CODE]));
865
870
}
866
871
867
872
{
868
- ImagePropSaveInfo ImgPSInfo = {SetSpecConstAtRT, SpecConstsMet ,
869
- EmitKernelParamInfo, EmitProgramMetadata ,
870
- EmitExportedSymbols, IsEsimd};
873
+ ImagePropSaveInfo ImgPSInfo = {SpecConstsMet, EmitKernelParamInfo ,
874
+ EmitProgramMetadata, EmitExportedSymbols ,
875
+ IsEsimd};
871
876
string_vector Files =
872
877
saveDeviceImageProperty (ResultModules, GlobalsSet, ImgPSInfo);
873
878
std::copy (Files.begin (), Files.end (),
@@ -876,29 +881,28 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
876
881
877
882
if (DoSymGen) {
878
883
// extract symbols for each module
879
- string_vector ResultSymbolsLists;
880
- collectSymbolsLists (GlobalsSet, ResultSymbolsLists);
884
+ string_vector ResultSymbolsLists = collectSymbolsLists (GlobalsSet);
881
885
if (ResultSymbolsLists.empty ()) {
882
886
// push empty symbols list for consistency
883
887
assert (ResultModules.size () == 1 );
884
888
ResultSymbolsLists.push_back (" " );
885
889
}
886
890
string_vector Files =
887
- saveResultSymbolsLists (ResultSymbolsLists, IsEsimd ? " esimd_ " : " " );
891
+ saveResultSymbolsLists (ResultSymbolsLists, FileSuffix );
888
892
std::copy (Files.begin (), Files.end (),
889
893
std::back_inserter (TblFiles[COL_SYM]));
890
894
}
891
895
892
896
return TblFiles;
893
897
}
894
898
895
- using ModulePair = std::pair<std::unique_ptr<Module>, std::unique_ptr<Module> >;
899
+ using ModulePair = std::pair<ModuleUPtr, ModuleUPtr >;
896
900
897
901
// This function splits a module with a mix of SYCL and ESIMD kernels
898
902
// into two separate modules.
899
- ModulePair splitSyclEsimd (std::unique_ptr<Module> M) {
900
- std::vector< const Function *> SyclFunctions;
901
- std::vector< const Function *> EsimdFunctions;
903
+ ModulePair splitSyclEsimd (ModuleUPtr M) {
904
+ FuncPtrVector SyclFunctions;
905
+ FuncPtrVector EsimdFunctions;
902
906
// Collect information about the SYCL and ESIMD functions in the module.
903
907
// Only process module entry points.
904
908
for (const auto &F : M->functions ()) {
@@ -912,30 +916,29 @@ ModulePair splitSyclEsimd(std::unique_ptr<Module> M) {
912
916
913
917
// If only SYCL kernels or only ESIMD kernels, no splitting needed.
914
918
if (EsimdFunctions.empty ())
915
- return std::make_pair (std::move (M), std::unique_ptr<Module> (nullptr ));
919
+ return std::make_pair (std::move (M), ModuleUPtr (nullptr ));
916
920
917
921
if (SyclFunctions.empty ())
918
- return std::make_pair (std::unique_ptr<Module> (nullptr ), std::move (M));
922
+ return std::make_pair (ModuleUPtr (nullptr ), std::move (M));
919
923
920
924
// Key values in KernelModuleMap are not significant, but they define the
921
925
// order, in which entry points are processed in the splitModule function. The
922
926
// caller of the splitSyclEsimd function expects a pair of 1-Sycl and 2-Esimd
923
927
// modules, hence the string names below.
924
- std::map<StringRef, std::vector< const Function *>> KernelModuleMap (
928
+ EntryPointsSet KernelModuleMap (
925
929
{{" 1-SYCL" , SyclFunctions}, {" 2-ESIMD" , EsimdFunctions}});
926
- std::vector<ResultModule> ResultModules;
927
- splitModule (*M, KernelModuleMap, ResultModules);
930
+ std::vector<ResultModule> ResultModules = splitModule (*M, KernelModuleMap);
928
931
assert (ResultModules.size () == 2 );
929
932
return std::make_pair (std::move (ResultModules[0 ].ModulePtr ),
930
933
std::move (ResultModules[1 ].ModulePtr ));
931
934
}
932
935
933
- TableFiles processInputModule (std::unique_ptr<Module> M) {
936
+ TableFiles processInputModule (ModuleUPtr M) {
934
937
if (!SplitEsimd)
935
938
return processOneModule (std::move (M), false , false );
936
939
937
- std::unique_ptr<Module> SyclModule;
938
- std::unique_ptr<Module> EsimdModule;
940
+ ModuleUPtr SyclModule;
941
+ ModuleUPtr EsimdModule;
939
942
std::tie (SyclModule, EsimdModule) = splitSyclEsimd (std::move (M));
940
943
941
944
// Do we have both Sycl and Esimd code?
@@ -1058,7 +1061,7 @@ int main(int argc, char **argv) {
1058
1061
return 1 ;
1059
1062
}
1060
1063
SMDiagnostic Err;
1061
- std::unique_ptr<Module> M = parseIRFile (InputFilename, Err, Context);
1064
+ ModuleUPtr M = parseIRFile (InputFilename, Err, Context);
1062
1065
// It is OK to use a raw pointer here as we ensure that it does not outlive M
1063
1066
// or objects it is moved to
1064
1067
Module *MPtr = M.get ();
0 commit comments