Skip to content

[sycl-post-link][NFC] Another portion of small refactorings #4807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 115 additions & 106 deletions llvm/tools/sycl-post-link/sycl-post-link.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
using namespace llvm;

using string_vector = std::vector<std::string>;
using PropSetRegTy = llvm::util::PropertySetRegistry;

namespace {

Expand Down Expand Up @@ -183,8 +184,6 @@ cl::opt<bool> EmitOnlyKernelsAsEntryPoints{
cl::cat(PostLinkCat), cl::init(false)};

struct ImagePropSaveInfo {
bool NeedDeviceLibReqMask;
bool DoSpecConst;
bool SetSpecConstAtRT;
bool SpecConstsMet;
bool EmitKernelParamInfo;
Expand Down Expand Up @@ -219,20 +218,13 @@ enum KernelMapEntryScope {
Scope_Global // single entry in the map for all kernels
};

KernelMapEntryScope selectDeviceCodeSplitScopeAutomatically(const Module &M) {
if (IROutputOnly) {
// We allow enabling auto split mode even in presence of -ir-output-only
// flag, but in this case we are limited by it so we can't do any split at
// all.
return Scope_Global;
}

bool hasIndirectFunctionCalls(const Module &M) {
for (const auto &F : M.functions()) {
// There are functions marked with [[intel::device_indirectly_callable]]
// attribute, because it instructs us to make this function available to the
// whole program as it was compiled as a single module.
if (F.hasFnAttribute("referenced-indirectly"))
return Scope_Global;
return true;
if (F.isDeclaration())
continue;
// There are indirect calls in the module, which means that we don't know
Expand All @@ -241,18 +233,47 @@ KernelMapEntryScope selectDeviceCodeSplitScopeAutomatically(const Module &M) {
for (const auto &I : instructions(F)) {
if (auto *CI = dyn_cast<CallInst>(&I))
if (!CI->getCalledFunction())
return Scope_Global;
return true;
}

// Function pointer is used somewhere. Follow the same rule as above.
for (const auto *U : F.users())
if (!isa<CallInst>(U))
return Scope_Global;
return true;
}

// At the moment, we assume that per-source split is the best way of splitting
// device code and can always be used execpt for cases handled above.
return Scope_PerModule;
return false;
}

KernelMapEntryScope selectDeviceCodeSplitScope(const Module &M) {
bool DoSplit = SplitMode.getNumOccurrences() > 0;
if (DoSplit) {
switch (SplitMode) {
case SPLIT_PER_TU:
return Scope_PerModule;

case SPLIT_PER_KERNEL:
return Scope_PerKernel;

case SPLIT_AUTO: {
if (IROutputOnly) {
// We allow enabling auto split mode even in presence of -ir-output-only
// flag, but in this case we are limited by it so we can't do any split
// at all.
return Scope_Global;
}

if (hasIndirectFunctionCalls(M))
return Scope_Global;

// At the moment, we assume that per-source split is the best way of
// splitting device code and can always be used except for cases handled
// above.
return Scope_PerModule;
}
}
}
return Scope_Global;
}

// Return true if the function is a SPIRV or SYCL builtin, e.g.
Expand Down Expand Up @@ -411,6 +432,41 @@ HasAssertStatus hasAssertInFunctionCallGraph(const Function *Func) {
return No_Assert;
}

std::vector<StringRef> getKernelNamesUsingAssert(const Module &M) {
std::vector<StringRef> Result;

bool HasIndirectlyCalledAssert = false;
std::vector<const Function *> Kernels;
for (const auto &F : M.functions()) {
// TODO: handle SYCL_EXTERNAL functions for dynamic linkage.
// TODO: handle function pointers.
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
continue;

Kernels.push_back(&F);
if (HasIndirectlyCalledAssert)
continue;

HasAssertStatus HasAssert = hasAssertInFunctionCallGraph(&F);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at implementation of this method I think this is very inefficient approach to find kernels calling __devicelib_assert_fail. IMHO, bottom-up approach is way more efficient. I.e. start with __devicelib_assert_fail function declaration and travers users gathering kernels.
Current implementation goes over all functions in the module even if __devicelib_assert_fail not declared!
Please, consider refactoring this code to improve performance.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fully agree with the suggestion, @mlychkov, could you please create a tracker for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'll create it.

switch (HasAssert) {
case Assert:
Result.push_back(F.getName());
break;
case Assert_Indirect:
HasIndirectlyCalledAssert = true;
break;
case No_Assert:
break;
}
}

if (HasIndirectlyCalledAssert)
for (const auto *F : Kernels)
Result.push_back(F->getName());

return Result;
}

// Gets reqd_work_group_size information for function Func.
std::vector<uint32_t> getKernelReqdWorkGroupSizeMetadata(const Function &Func) {
auto ReqdWorkGroupSizeMD = Func.getMetadata("reqd_work_group_size");
Expand Down Expand Up @@ -545,7 +601,6 @@ string_vector saveResultModules(const std::vector<ResultModule> &ResModules,
string_vector Res;

for (size_t I = 0; I < ResModules.size(); ++I) {
std::error_code EC;
StringRef FileExt = (OutputAssembly) ? ".ll" : ".bc";
std::string CurOutFileName = makeResultFileName(FileExt, I, Suffix);
saveModule(*ResModules[I].ModulePtr, CurOutFileName);
Expand All @@ -560,40 +615,36 @@ string_vector saveDeviceImageProperty(
const ImagePropSaveInfo &ImgPSInfo) {
string_vector Res;
legacy::PassManager GetSYCLDeviceLibReqMask;
SYCLDeviceLibReqMaskPass *SDLReqMaskLegacyPass =
new SYCLDeviceLibReqMaskPass();
auto *SDLReqMaskLegacyPass = new SYCLDeviceLibReqMaskPass();
GetSYCLDeviceLibReqMask.add(SDLReqMaskLegacyPass);
for (size_t I = 0; I < ResultModules.size(); ++I) {
llvm::util::PropertySetRegistry PropSet;
if (ImgPSInfo.NeedDeviceLibReqMask) {
GetSYCLDeviceLibReqMask.run(*ResultModules[I].ModulePtr);
Module &M = *ResultModules[I].ModulePtr;
PropSetRegTy PropSet;

{
GetSYCLDeviceLibReqMask.run(M);
uint32_t MRMask = SDLReqMaskLegacyPass->getSYCLDeviceLibReqMask();
std::map<StringRef, uint32_t> RMEntry = {{"DeviceLibReqMask", MRMask}};
PropSet.add(llvm::util::PropertySetRegistry::SYCL_DEVICELIB_REQ_MASK,
RMEntry);
PropSet.add(PropSetRegTy::SYCL_DEVICELIB_REQ_MASK, RMEntry);
}
if (ImgPSInfo.DoSpecConst) {
if (ImgPSInfo.SpecConstsMet) {
// extract spec constant maps per each module
SpecIDMapTy TmpSpecIDMap;
SpecConstantsPass::collectSpecConstantMetadata(
*ResultModules[I].ModulePtr, TmpSpecIDMap);
PropSet.add(
llvm::util::PropertySetRegistry::SYCL_SPECIALIZATION_CONSTANTS,
TmpSpecIDMap);

// Add property with the default values of spec constants only in native
// (default) mode.
if (!ImgPSInfo.SetSpecConstAtRT) {
std::vector<char> DefaultValues;
SpecConstantsPass::collectSpecConstantDefaultValuesMetadata(
*ResultModules[I].ModulePtr, DefaultValues);
PropSet.add(llvm::util::PropertySetRegistry::
SYCL_SPEC_CONSTANTS_DEFAULT_VALUES,
"all", DefaultValues);
}

if (ImgPSInfo.SpecConstsMet) {
// extract spec constant maps per each module
SpecIDMapTy TmpSpecIDMap;
SpecConstantsPass::collectSpecConstantMetadata(M, TmpSpecIDMap);
PropSet.add(PropSetRegTy::SYCL_SPECIALIZATION_CONSTANTS, TmpSpecIDMap);

// Add property with the default values of spec constants only in native
// (default) mode.
if (!ImgPSInfo.SetSpecConstAtRT) {
std::vector<char> DefaultValues;
SpecConstantsPass::collectSpecConstantDefaultValuesMetadata(
M, DefaultValues);
PropSet.add(PropSetRegTy::SYCL_SPEC_CONSTANTS_DEFAULT_VALUES, "all",
DefaultValues);
}
}

if (ImgPSInfo.EmitKernelParamInfo) {
// extract kernel parameter optimization info per module
ModuleAnalysisManager MAM;
Expand All @@ -603,12 +654,11 @@ string_vector saveDeviceImageProperty(

MAM.registerPass([&] { return SYCLKernelParamOptInfoAnalysis(); });
SYCLKernelParamOptInfo PInfo =
MAM.getResult<SYCLKernelParamOptInfoAnalysis>(
*ResultModules[I].ModulePtr);
MAM.getResult<SYCLKernelParamOptInfoAnalysis>(M);

// convert analysis results into properties and record them
llvm::util::PropertySet &Props =
PropSet[llvm::util::PropertySetRegistry::SYCL_KERNEL_PARAM_OPT_INFO];
PropSet[PropSetRegTy::SYCL_KERNEL_PARAM_OPT_INFO];

for (const auto &NameInfoPair : PInfo) {
const llvm::BitVector &Bits = NameInfoPair.second;
Expand All @@ -623,15 +673,16 @@ string_vector saveDeviceImageProperty(
NameInfoPair.first, llvm::util::PropertyValue(Data, DataBitSize)));
}
}

if (ImgPSInfo.EmitExportedSymbols) {
// For each result module, extract the exported functions
auto ModuleFunctionsIt =
KernelModuleMap.find(ResultModules[I].KernelModuleName);
if (ModuleFunctionsIt != KernelModuleMap.end()) {
for (const auto &F : ModuleFunctionsIt->second) {
if (F->getCallingConv() == CallingConv::SPIR_FUNC) {
PropSet[llvm::util::PropertySetRegistry::SYCL_EXPORTED_SYMBOLS]
.insert({F->getName(), true});
PropSet[PropSetRegTy::SYCL_EXPORTED_SYMBOLS].insert(
{F->getName(), true});
}
}
}
Expand All @@ -641,11 +692,10 @@ string_vector saveDeviceImageProperty(
// properties have been written.
SmallVector<std::string, 4> MetadataNames;
if (ImgPSInfo.EmitProgramMetadata) {
auto &ProgramMetadata =
PropSet[llvm::util::PropertySetRegistry::SYCL_PROGRAM_METADATA];
auto &ProgramMetadata = PropSet[PropSetRegTy::SYCL_PROGRAM_METADATA];

// Add reqd_work_group_size information to program metadata
for (const Function &Func : ResultModules[I].ModulePtr->functions()) {
for (const Function &Func : M.functions()) {
std::vector<uint32_t> KernelReqdWorkGroupSize =
getKernelReqdWorkGroupSizeMetadata(Func);
if (KernelReqdWorkGroupSize.empty())
Expand All @@ -655,44 +705,13 @@ string_vector saveDeviceImageProperty(
}
}

if (ImgPSInfo.IsEsimdKernel) {
PropSet[llvm::util::PropertySetRegistry::SYCL_MISC_PROP].insert(
{"isEsimdImage", true});
}
if (ImgPSInfo.IsEsimdKernel)
PropSet[PropSetRegTy::SYCL_MISC_PROP].insert({"isEsimdImage", true});

{
Module *M = ResultModules[I].ModulePtr.get();
bool HasIndirectlyCalledAssert = false;
std::vector<const Function *> Kernels;
for (const auto &F : M->functions()) {
// TODO: handle SYCL_EXTERNAL functions for dynamic linkage.
// TODO: handle function pointers.
if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
continue;

Kernels.push_back(&F);
if (HasIndirectlyCalledAssert)
continue;

HasAssertStatus HasAssert = hasAssertInFunctionCallGraph(&F);
switch (HasAssert) {
case Assert:
PropSet[llvm::util::PropertySetRegistry::SYCL_ASSERT_USED].insert(
{F.getName(), true});
break;
case Assert_Indirect:
HasIndirectlyCalledAssert = true;
break;
case No_Assert:
break;
}
}

if (HasIndirectlyCalledAssert) {
for (const auto *F : Kernels)
PropSet[llvm::util::PropertySetRegistry::SYCL_ASSERT_USED].insert(
{F->getName(), true});
}
std::vector<StringRef> FuncNames = getKernelNamesUsingAssert(M);
for (const StringRef &FName : FuncNames)
PropSet[PropSetRegTy::SYCL_ASSERT_USED].insert({FName, true});
}

std::error_code EC;
Expand Down Expand Up @@ -786,23 +805,15 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
std::map<StringRef, std::vector<const Function *>> GlobalsSet;

bool DoSplit = SplitMode.getNumOccurrences() > 0;
bool DoSpecConst = SpecConstLower.getNumOccurrences() > 0;

if (DoSplit || DoSymGen) {
KernelMapEntryScope Scope = Scope_Global;
if (DoSplit) {
if (SplitMode == SPLIT_AUTO)
Scope = selectDeviceCodeSplitScopeAutomatically(*M);
else
Scope =
SplitMode == SPLIT_PER_KERNEL ? Scope_PerKernel : Scope_PerModule;
}
KernelMapEntryScope Scope = selectDeviceCodeSplitScope(*M);
collectEntryPointToModuleMap(*M, GlobalsSet, Scope);
}

std::vector<ResultModule> ResultModules;
string_vector ResultSymbolsLists;

bool DoSpecConst = SpecConstLower.getNumOccurrences() > 0;
bool SpecConstsMet = false;
bool SetSpecConstAtRT = DoSpecConst && (SpecConstLower == SC_USE_RT_VAL);

Expand Down Expand Up @@ -852,21 +863,18 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
}

{
ImagePropSaveInfo ImgPSInfo = {true,
DoSpecConst,
SetSpecConstAtRT,
SpecConstsMet,
EmitKernelParamInfo,
EmitProgramMetadata,
EmitExportedSymbols,
IsEsimd};
ImagePropSaveInfo ImgPSInfo = {SetSpecConstAtRT, SpecConstsMet,
EmitKernelParamInfo, EmitProgramMetadata,
EmitExportedSymbols, IsEsimd};
string_vector Files =
saveDeviceImageProperty(ResultModules, GlobalsSet, ImgPSInfo);
std::copy(Files.begin(), Files.end(),
std::back_inserter(TblFiles[COL_PROPS]));
}

if (DoSymGen) {
// extract symbols per each module
string_vector ResultSymbolsLists;
collectSymbolsLists(GlobalsSet, ResultSymbolsLists);
if (ResultSymbolsLists.empty()) {
// push empty symbols list for consistency
Expand All @@ -878,6 +886,7 @@ TableFiles processOneModule(std::unique_ptr<Module> M, bool IsEsimd,
std::copy(Files.begin(), Files.end(),
std::back_inserter(TblFiles[COL_SYM]));
}

return TblFiles;
}

Expand Down