Skip to content

Commit 29030f9

Browse files
committed
[SYCL][NFC] Refactor SpecConstantsPass
This refactoring is needed to simplify upcoming functional changes. Outlined some code into a helper function Removed `getDefaultValue` function: it is unclear why default value of spec constant would be different for AOT and non-AOT flows. Removed unused SymGlob argument of getStringLiteralArg. Removed unneeded static keyword. Added `const` to a few arguments of different helpers within the pass.
1 parent 3b8af22 commit 29030f9

File tree

1 file changed

+33
-43
lines changed

1 file changed

+33
-43
lines changed

llvm/tools/sycl-post-link/SpecConstants.cpp

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,13 @@ constexpr char SPIRV_GET_SPEC_CONST_VAL[] = "__spirv_SpecConstant";
3333
// original symbolic spec constant ID.
3434
constexpr char SPEC_CONST_SYM_ID_MD_STRING[] = "SYCL_SPEC_CONST_SYM_ID";
3535

36-
static void AssertRelease(bool Cond, const char *Msg) {
36+
void AssertRelease(bool Cond, const char *Msg) {
3737
if (!Cond)
3838
report_fatal_error((Twine("SpecConstants.cpp: ") + Msg).str().c_str());
3939
}
4040

4141
StringRef getStringLiteralArg(const CallInst *CI, unsigned ArgNo,
42-
SmallVectorImpl<Instruction *> &DelInsts,
43-
GlobalVariable *&SymGlob) {
42+
SmallVectorImpl<Instruction *> &DelInsts) {
4443
Value *V = CI->getArgOperand(ArgNo)->stripPointerCasts();
4544

4645
if (auto *L = dyn_cast<LoadInst>(V)) {
@@ -95,7 +94,6 @@ StringRef getStringLiteralArg(const CallInst *CI, unsigned ArgNo,
9594
V = Store->getValueOperand()->stripPointerCasts();
9695
}
9796
const Constant *Init = cast<GlobalVariable>(V)->getInitializer();
98-
SymGlob = cast<GlobalVariable>(V);
9997
StringRef Res = cast<ConstantDataArray>(Init)->getAsString();
10098
if (Res.size() > 0 && Res[Res.size() - 1] == '\0')
10199
Res = Res.substr(0, Res.size() - 1);
@@ -104,16 +102,16 @@ StringRef getStringLiteralArg(const CallInst *CI, unsigned ArgNo,
104102

105103
// TODO support spec constant types other than integer or
106104
// floating-point.
107-
Value *genDefaultValue(Type *T, Instruction *At) {
105+
Value *getDefaultCPPValue(Type *T) {
108106
if (T->isIntegerTy())
109-
return ConstantInt::get(T, 0);
107+
return Constant::getIntegerValue(T, APInt(T->getScalarSizeInBits(), 0));
110108
if (T->isFloatingPointTy())
111109
return ConstantFP::get(T, 0.0);
112110
llvm_unreachable("non-numeric specialization constants are NYI");
113111
return nullptr;
114112
}
115113

116-
std::string manglePrimitiveType(Type *T) {
114+
std::string manglePrimitiveType(const Type *T) {
117115
if (T->isFloatTy())
118116
return "f";
119117
if (T->isDoubleTy())
@@ -139,7 +137,7 @@ std::string manglePrimitiveType(Type *T) {
139137

140138
// This is a very basic mangler which can mangle non-templated and non-member
141139
// functions with primitive types in the signature.
142-
std::string mangleFuncItanium(StringRef BaseName, FunctionType *FT) {
140+
std::string mangleFuncItanium(StringRef BaseName, const FunctionType *FT) {
143141
std::string Res =
144142
(Twine("_Z") + Twine(BaseName.size()) + Twine(BaseName)).str();
145143
for (unsigned I = 0; I < FT->getNumParams(); ++I)
@@ -156,7 +154,7 @@ void setSpecConstMetadata(Instruction *I, StringRef SymID, int IntID) {
156154
I->setMetadata(SPEC_CONST_SYM_ID_MD_STRING, Entry);
157155
}
158156

159-
std::pair<StringRef, unsigned> getSpecConstMetadata(Instruction *I) {
157+
std::pair<StringRef, unsigned> getSpecConstMetadata(const Instruction *I) {
160158
const MDNode *N = I->getMetadata(SPEC_CONST_SYM_ID_MD_STRING);
161159
if (!N)
162160
return std::make_pair("", 0);
@@ -167,13 +165,28 @@ std::pair<StringRef, unsigned> getSpecConstMetadata(Instruction *I) {
167165
return std::make_pair(MDSym->getString(), ID);
168166
}
169167

170-
static Value *getDefaultCPPValue(Type *T) {
171-
if (T->isIntegerTy())
172-
return Constant::getIntegerValue(T, APInt(T->getScalarSizeInBits(), 0));
173-
if (T->isFloatingPointTy())
174-
return ConstantFP::get(T, 0);
175-
llvm_unreachable("unsupported spec const type");
176-
return nullptr;
168+
Instruction *emitSpecConstant(int NumericID, Type *Ty,
169+
Instruction *InsertBefore) {
170+
Function *F = InsertBefore->getFunction();
171+
// Generate arguments needed by the SPIRV version of the intrinsic
172+
// - integer constant ID:
173+
Value *ID = ConstantInt::get(Type::getInt32Ty(F->getContext()), NumericID);
174+
// - default value:
175+
Value *Def = getDefaultCPPValue(Ty);
176+
// ... Now replace the call with SPIRV intrinsic version.
177+
Value *Args[] = {ID, Def};
178+
constexpr size_t NArgs = sizeof(Args) / sizeof(Args[0]);
179+
Type *ArgTys[NArgs] = {nullptr};
180+
for (unsigned int I = 0; I < NArgs; ++I)
181+
ArgTys[I] = Args[I]->getType();
182+
FunctionType *FT = FunctionType::get(Ty, ArgTys, false /*isVarArg*/);
183+
Module *M = F->getParent();
184+
std::string SPIRVName = mangleFuncItanium(SPIRV_GET_SPEC_CONST_VAL, FT);
185+
FunctionCallee FC = M->getOrInsertFunction(SPIRVName, FT);
186+
assert(FC.getCallee() && "SPIRV intrinsic creation failed");
187+
CallInst *SpecConstant =
188+
CallInst::Create(FT, FC.getCallee(), Args, "", InsertBefore);
189+
return SpecConstant;
177190
}
178191

179192
} // namespace
@@ -198,10 +211,8 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
198211

199212
SmallVector<CallInst *, 32> SCIntrCalls;
200213
for (auto *U : F.users()) {
201-
auto *CI = dyn_cast<CallInst>(U);
202-
if (!CI)
203-
continue;
204-
SCIntrCalls.push_back(CI);
214+
if (auto *CI = dyn_cast<CallInst>(U))
215+
SCIntrCalls.push_back(CI);
205216
}
206217

207218
IRModified = IRModified || (SCIntrCalls.size() > 0);
@@ -213,8 +224,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
213224
// code can't use this intrinsic directly.
214225
SmallVector<Instruction *, 3> DelInsts;
215226
DelInsts.push_back(CI);
216-
GlobalVariable *SymGlob = nullptr;
217-
StringRef SymID = getStringLiteralArg(CI, 0, DelInsts, SymGlob);
227+
StringRef SymID = getStringLiteralArg(CI, 0, DelInsts);
218228
Type *SCTy = CI->getType();
219229

220230
if (SetValAtRT) {
@@ -225,25 +235,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
225235
if (Ins.second)
226236
Ins.first->second = NextID++;
227237
// 3. Transform to spirv intrinsic _Z*__spirv_SpecConstant*.
228-
LLVMContext &Ctx = F.getContext();
229-
// Generate arguments needed by the SPIRV version of the intrinsic
230-
// - integer constant ID:
231-
Value *ID = ConstantInt::get(Type::getInt32Ty(Ctx), NextID - 1);
232-
// - default value:
233-
Value *Def = genDefaultValue(SCTy, CI);
234-
// ... Now replace the call with SPIRV intrinsic version.
235-
Value *Args[] = {ID, Def};
236-
constexpr size_t NArgs = sizeof(Args) / sizeof(Args[0]);
237-
Type *ArgTys[NArgs] = {nullptr};
238-
for (unsigned int I = 0; I < NArgs; ++I)
239-
ArgTys[I] = Args[I]->getType();
240-
FunctionType *FT = FunctionType::get(SCTy, ArgTys, false /*isVarArg*/);
241-
Module &M = *F.getParent();
242-
std::string SPIRVName = mangleFuncItanium(SPIRV_GET_SPEC_CONST_VAL, FT);
243-
FunctionCallee FC = M.getOrInsertFunction(SPIRVName, FT);
244-
assert(FC.getCallee() && "SPIRV intrinsic creation failed");
245-
CallInst *SPIRVCall =
246-
CallInst::Create(FT, FC.getCallee(), Args, "", CI);
238+
auto *SPIRVCall = emitSpecConstant(NextID - 1, SCTy, CI);
247239
CI->replaceAllUsesWith(SPIRVCall);
248240
// Mark the instruction with <symbolic_id, int_id> pair for later
249241
// recollection by collectSpecConstantMetadata method.
@@ -261,8 +253,6 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
261253
I->removeFromParent();
262254
I->deleteValue();
263255
}
264-
// Don't delete SymGlob here, as it may be referenced from multiple
265-
// functions if __sycl_getSpecConstantValue is inlined.
266256
}
267257
}
268258
return IRModified ? PreservedAnalyses::none() : PreservedAnalyses::all();

0 commit comments

Comments
 (0)