Skip to content

Commit 8ccf00b

Browse files
committed
[sycl-post-link][NFC] Small cleanups and cosmetic changes
Signed-off-by: Mikhail Lychkov <[email protected]>
1 parent c855fd1 commit 8ccf00b

File tree

1 file changed

+64
-61
lines changed

1 file changed

+64
-61
lines changed

llvm/tools/sycl-post-link/sycl-post-link.cpp

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,19 @@
4444
#include "llvm/Transforms/Utils/Cloning.h"
4545

4646
#include <algorithm>
47+
#include <map>
4748
#include <memory>
4849
#include <string>
4950
#include <vector>
5051

5152
using namespace llvm;
5253

5354
using string_vector = std::vector<std::string>;
55+
using FuncPtrVector = std::vector<const Function *>;
56+
using EntryPointsSet = std::map<StringRef, FuncPtrVector>;
57+
using ModuleUPtr = std::unique_ptr<Module>;
5458
using PropSetRegTy = llvm::util::PropertySetRegistry;
59+
using StringRefVector = std::vector<StringRef>;
5560

5661
namespace {
5762

@@ -184,7 +189,6 @@ cl::opt<bool> EmitOnlyKernelsAsEntryPoints{
184189
cl::cat(PostLinkCat), cl::init(false)};
185190

186191
struct ImagePropSaveInfo {
187-
bool SetSpecConstAtRT;
188192
bool SpecConstsMet;
189193
bool EmitKernelParamInfo;
190194
bool EmitProgramMetadata;
@@ -314,10 +318,9 @@ bool isEntryPoint(const Function &F) {
314318
// ResKernelModuleMap which maps some key to a group of entry points. Each such
315319
// group along with IR it depends on (globals, functions from its call graph,
316320
// ...) will constitute a separate module.
317-
void collectEntryPointToModuleMap(
318-
const Module &M,
319-
std::map<StringRef, std::vector<const Function *>> &ResKernelModuleMap,
320-
KernelMapEntryScope EntryScope) {
321+
void collectEntryPointToModuleMap(const Module &M,
322+
EntryPointsSet &ResKernelModuleMap,
323+
KernelMapEntryScope EntryScope) {
321324

322325
// Only process module entry points:
323326
for (const auto &F : M.functions()) {
@@ -358,11 +361,11 @@ HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) {
358361
// true - if there is an assertion in underlying functions,
359362
// false - if there are definetely no assertions in underlying functions.
360363
static std::map<const Function *, bool> hasAssertionInCallGraphMap;
361-
std::vector<const Function *> FuncCallStack;
364+
FuncPtrVector FuncCallStack;
362365

363-
static std::vector<const Function *> isIndirectlyCalledInGraph;
366+
static FuncPtrVector isIndirectlyCalledInGraph;
364367

365-
std::vector<const Function *> Workstack;
368+
FuncPtrVector Workstack;
366369
Workstack.push_back(Func);
367370

368371
while (!Workstack.empty()) {
@@ -437,11 +440,11 @@ HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) {
437440
return No_Assert;
438441
}
439442

440-
std::vector<StringRef> getKernelNamesUsingAssert(const Module &M) {
441-
std::vector<StringRef> Result;
443+
StringRefVector getKernelNamesUsingAssert(const Module &M) {
444+
StringRefVector Result;
442445

443446
bool HasIndirectlyCalledAssert = false;
444-
std::vector<const Function *> Kernels;
447+
FuncPtrVector Kernels;
445448
for (const auto &F : M.functions()) {
446449
// TODO: handle SYCL_EXTERNAL functions for dynamic linkage.
447450
// TODO: handle function pointers.
@@ -489,14 +492,13 @@ std::vector<uint32_t> getKernelReqdWorkGroupSizeMetadata(const Function &Func) {
489492
}
490493

491494
// Input parameter KernelModuleMap is a map containing groups of entry points
492-
// with same values of the sycl-module-id attribute. ResSymbolsLists is a vector
495+
// with same values of the sycl-module-id attribute. Return value is a vector
493496
// of entry points names lists. Each vector element is a string with entry point
494497
// names from the same module separated by \n.
495498
// The function saves names of entry points from one group to a single
496499
// std::string and stores this string to the ResSymbolsLists vector.
497-
void collectSymbolsLists(
498-
const std::map<StringRef, std::vector<const Function *>> &KernelModuleMap,
499-
string_vector &ResSymbolsLists) {
500+
string_vector collectSymbolsLists(const EntryPointsSet &KernelModuleMap) {
501+
string_vector ResSymbolsLists{};
500502
for (const auto &It : KernelModuleMap) {
501503
std::string SymbolsList;
502504
for (const auto &F : It.second) {
@@ -505,11 +507,12 @@ void collectSymbolsLists(
505507
}
506508
ResSymbolsLists.push_back(std::move(SymbolsList));
507509
}
510+
return ResSymbolsLists;
508511
}
509512

510513
struct ResultModule {
511514
StringRef KernelModuleName;
512-
std::unique_ptr<Module> ModulePtr;
515+
ModuleUPtr ModulePtr;
513516
};
514517

515518
// Input parameter KernelModuleMap is a map containing groups of entry points
@@ -518,14 +521,14 @@ struct ResultModule {
518521
// ResModules is a vector of pairs of kernel module names and produced modules.
519522
// The function splits input LLVM IR module M into smaller ones and stores them
520523
// to the ResModules vector.
521-
void splitModule(
522-
const Module &M,
523-
const std::map<StringRef, std::vector<const Function *>> &KernelModuleMap,
524-
std::vector<ResultModule> &ResModules) {
524+
std::vector<ResultModule> splitModule(const Module &M,
525+
const EntryPointsSet &KernelModuleMap) {
526+
std::vector<ResultModule> ResModules{};
527+
525528
for (const auto &It : KernelModuleMap) {
526529
// For each group of entry points collect all dependencies.
527530
SetVector<const GlobalValue *> GVs;
528-
std::vector<const Function *> Workqueue;
531+
FuncPtrVector Workqueue;
529532

530533
for (const auto &F : It.second) {
531534
GVs.insert(F);
@@ -556,7 +559,7 @@ void splitModule(
556559
ValueToValueMapTy VMap;
557560
// Clone definitions only for needed globals. Others will be added as
558561
// declarations and removed later.
559-
std::unique_ptr<Module> MClone = CloneModule(
562+
ModuleUPtr MClone = CloneModule(
560563
M, VMap, [&](const GlobalValue *GV) { return GVs.count(GV); });
561564

562565
// TODO: Use the new PassManager instead?
@@ -570,6 +573,8 @@ void splitModule(
570573
// Save results.
571574
ResModules.push_back({It.first, std::move(MClone)});
572575
}
576+
577+
return ResModules;
573578
}
574579

575580
std::string makeResultFileName(Twine Ext, int I, StringRef Suffix) {
@@ -614,10 +619,10 @@ string_vector saveResultModules(const std::vector<ResultModule> &ResModules,
614619
return Res;
615620
}
616621

617-
string_vector saveDeviceImageProperty(
618-
const std::vector<ResultModule> &ResultModules,
619-
const std::map<StringRef, std::vector<const Function *>> &KernelModuleMap,
620-
const ImagePropSaveInfo &ImgPSInfo) {
622+
string_vector
623+
saveDeviceImageProperty(const std::vector<ResultModule> &ResultModules,
624+
const EntryPointsSet &KernelModuleMap,
625+
const ImagePropSaveInfo &ImgPSInfo) {
621626
string_vector Res;
622627
legacy::PassManager GetSYCLDeviceLibReqMask;
623628
auto *SDLReqMaskLegacyPass = new SYCLDeviceLibReqMaskPass();
@@ -711,7 +716,7 @@ string_vector saveDeviceImageProperty(
711716
PropSet[PropSetRegTy::SYCL_MISC_PROP].insert({"isEsimdImage", true});
712717

713718
{
714-
std::vector<StringRef> FuncNames = getKernelNamesUsingAssert(M);
719+
StringRefVector FuncNames = getKernelNamesUsingAssert(M);
715720
for (const StringRef &FName : FuncNames)
716721
PropSet[PropSetRegTy::SYCL_ASSERT_USED].insert({FName, true});
717722
}
@@ -781,8 +786,7 @@ void LowerEsimdConstructs(Module &M) {
781786

782787
using TableFiles = std::map<StringRef, string_vector>;
783788

784-
TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
785-
bool SyclAndEsimdCode) {
789+
TableFiles processOneModule(ModuleUPtr M, bool IsEsimd, bool SyclAndEsimdCode) {
786790
TableFiles TblFiles;
787791
if (!M)
788792
return TblFiles;
@@ -804,7 +808,7 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
804808
if (IsEsimd && LowerEsimd)
805809
LowerEsimdConstructs(*M);
806810

807-
std::map<StringRef, std::vector<const Function *>> GlobalsSet;
811+
EntryPointsSet GlobalsSet;
808812

809813
bool DoSplit = SplitMode.getNumOccurrences() > 0;
810814

@@ -813,19 +817,21 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
813817
collectEntryPointToModuleMap(*M, GlobalsSet, Scope);
814818
}
815819

816-
std::vector<ResultModule> ResultModules;
820+
StringRef FileSuffix = IsEsimd ? "esimd_" : "";
817821

818-
bool DoSpecConst = SpecConstLower.getNumOccurrences() > 0;
819-
bool SpecConstsMet = false;
820-
bool SetSpecConstAtRT = DoSpecConst && (SpecConstLower == SC_USE_RT_VAL);
822+
std::vector<ResultModule> ResultModules;
821823

822824
if (DoSplit)
823-
splitModule(*M, GlobalsSet, ResultModules);
825+
ResultModules = splitModule(*M, GlobalsSet);
824826
// post-link always produces a code result, even if it is unmodified input
825827
if (ResultModules.empty())
826828
ResultModules.push_back({GLOBAL_SCOPE_NAME, std::move(M)});
827829

830+
bool DoSpecConst = SpecConstLower.getNumOccurrences() > 0;
831+
bool SpecConstsMet = false;
832+
828833
if (DoSpecConst) {
834+
bool SetSpecConstAtRT = (SpecConstLower == SC_USE_RT_VAL);
829835
ModulePassManager RunSpecConst;
830836
ModuleAnalysisManager MAM;
831837
SpecConstantsPass SCP(SetSpecConstAtRT);
@@ -851,23 +857,22 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
851857
// Reuse input module with only regular SYCL kernels if there were
852858
// no spec constants and no splitting.
853859
// We cannot reuse input module for ESIMD code since it was transformed.
854-
bool CanReuseInputModule = !SpecConstsMet && (ResultModules.size() == 1) &&
855-
!SyclAndEsimdCode && !IsEsimd &&
856-
!IsLLVMUsedRemoved;
857-
string_vector Files =
858-
CanReuseInputModule
859-
? string_vector{InputFilename}
860-
: saveResultModules(ResultModules, IsEsimd ? "esimd_" : "");
860+
bool CanReuseInputModule = !SyclAndEsimdCode && !IsEsimd &&
861+
!IsLLVMUsedRemoved && !SpecConstsMet &&
862+
(ResultModules.size() == 1);
863+
string_vector Files = CanReuseInputModule
864+
? string_vector{InputFilename}
865+
: saveResultModules(ResultModules, FileSuffix);
861866

862867
// "Code" column is always output
863868
std::copy(Files.begin(), Files.end(),
864869
std::back_inserter(TblFiles[COL_CODE]));
865870
}
866871

867872
{
868-
ImagePropSaveInfo ImgPSInfo = {SetSpecConstAtRT, SpecConstsMet,
869-
EmitKernelParamInfo, EmitProgramMetadata,
870-
EmitExportedSymbols, IsEsimd};
873+
ImagePropSaveInfo ImgPSInfo = {SpecConstsMet, EmitKernelParamInfo,
874+
EmitProgramMetadata, EmitExportedSymbols,
875+
IsEsimd};
871876
string_vector Files =
872877
saveDeviceImageProperty(ResultModules, GlobalsSet, ImgPSInfo);
873878
std::copy(Files.begin(), Files.end(),
@@ -876,29 +881,28 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
876881

877882
if (DoSymGen) {
878883
// extract symbols per each module
879-
string_vector ResultSymbolsLists;
880-
collectSymbolsLists(GlobalsSet, ResultSymbolsLists);
884+
string_vector ResultSymbolsLists = collectSymbolsLists(GlobalsSet);
881885
if (ResultSymbolsLists.empty()) {
882886
// push empty symbols list for consistency
883887
assert(ResultModules.size() == 1);
884888
ResultSymbolsLists.push_back("");
885889
}
886890
string_vector Files =
887-
saveResultSymbolsLists(ResultSymbolsLists, IsEsimd ? "esimd_" : "");
891+
saveResultSymbolsLists(ResultSymbolsLists, FileSuffix);
888892
std::copy(Files.begin(), Files.end(),
889893
std::back_inserter(TblFiles[COL_SYM]));
890894
}
891895

892896
return TblFiles;
893897
}
894898

895-
using ModulePair = std::pair<std::unique_ptr<Module>, std::unique_ptr<Module>>;
899+
using ModulePair = std::pair<ModuleUPtr, ModuleUPtr>;
896900

897901
// This function splits a module with a mix of SYCL and ESIMD kernels
898902
// into two separate modules.
899-
ModulePair splitSyclEsimd(std::unique_ptr<Module> M) {
900-
std::vector<const Function *> SyclFunctions;
901-
std::vector<const Function *> EsimdFunctions;
903+
ModulePair splitSyclEsimd(ModuleUPtr M) {
904+
FuncPtrVector SyclFunctions;
905+
FuncPtrVector EsimdFunctions;
902906
// Collect information about the SYCL and ESIMD functions in the module.
903907
// Only process module entry points.
904908
for (const auto &F : M->functions()) {
@@ -912,30 +916,29 @@ ModulePair splitSyclEsimd(std::unique_ptr<Module> M) {
912916

913917
// If only SYCL kernels or only ESIMD kernels, no splitting needed.
914918
if (EsimdFunctions.empty())
915-
return std::make_pair(std::move(M), std::unique_ptr<Module>(nullptr));
919+
return std::make_pair(std::move(M), ModuleUPtr(nullptr));
916920

917921
if (SyclFunctions.empty())
918-
return std::make_pair(std::unique_ptr<Module>(nullptr), std::move(M));
922+
return std::make_pair(ModuleUPtr(nullptr), std::move(M));
919923

920924
// Key values in KernelModuleMap are not significant, but they define the
921925
// order, in which entry points are processed in the splitModule function. The
922926
// caller of the splitSyclEsimd function expects a pair of 1-Sycl and 2-Esimd
923927
// modules, hence the strings names below.
924-
std::map<StringRef, std::vector<const Function *>> KernelModuleMap(
928+
EntryPointsSet KernelModuleMap(
925929
{{"1-SYCL", SyclFunctions}, {"2-ESIMD", EsimdFunctions}});
926-
std::vector<ResultModule> ResultModules;
927-
splitModule(*M, KernelModuleMap, ResultModules);
930+
std::vector<ResultModule> ResultModules = splitModule(*M, KernelModuleMap);
928931
assert(ResultModules.size() == 2);
929932
return std::make_pair(std::move(ResultModules[0].ModulePtr),
930933
std::move(ResultModules[1].ModulePtr));
931934
}
932935

933-
TableFiles processInputModule(std::unique_ptr<Module> M) {
936+
TableFiles processInputModule(ModuleUPtr M) {
934937
if (!SplitEsimd)
935938
return processOneModule(std::move(M), false, false);
936939

937-
std::unique_ptr<Module> SyclModule;
938-
std::unique_ptr<Module> EsimdModule;
940+
ModuleUPtr SyclModule;
941+
ModuleUPtr EsimdModule;
939942
std::tie(SyclModule, EsimdModule) = splitSyclEsimd(std::move(M));
940943

941944
// Do we have both Sycl and Esimd code?
@@ -1058,7 +1061,7 @@ int main(int argc, char **argv) {
10581061
return 1;
10591062
}
10601063
SMDiagnostic Err;
1061-
std::unique_ptr<Module> M = parseIRFile(InputFilename, Err, Context);
1064+
ModuleUPtr M = parseIRFile(InputFilename, Err, Context);
10621065
// It is OK to use raw pointer here as we control that it does not outlive M
10631066
// or objects it is moved to
10641067
Module *MPtr = M.get();

0 commit comments

Comments
 (0)