Skip to content

Commit ac3a4af

Browse files
committed
Add test for ESIMDLowerVecArg pass. Refactor code as per code comments.
Signed-off-by: Ashar, Pratik J <[email protected]>
1 parent fb752d6 commit ac3a4af

File tree

6 files changed

+319
-145
lines changed

6 files changed

+319
-145
lines changed

llvm/include/llvm/InitializePasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ void initializeStructurizeCFGPass(PassRegistry&);
416416
void initializeSYCLLowerWGScopeLegacyPassPass(PassRegistry &);
417417
void initializeSYCLLowerESIMDLegacyPassPass(PassRegistry &);
418418
void initializeESIMDLowerLoadStorePass(PassRegistry &);
419+
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
419420
void initializeTailCallElimPass(PassRegistry&);
420421
void initializeTailDuplicatePass(PassRegistry&);
421422
void initializeTargetLibraryInfoWrapperPassPass(PassRegistry&);

llvm/include/llvm/LinkAllPasses.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ namespace {
204204
(void)llvm::createSYCLLowerWGScopePass();
205205
(void)llvm::createSYCLLowerESIMDPass();
206206
(void)llvm::createESIMDLowerLoadStorePass();
207+
(void)llvm::createESIMDLowerVecArgPass();
207208
std::string buf;
208209
llvm::raw_string_ostream os(buf);
209210
(void) llvm::createPrintModulePass(os);

llvm/lib/SYCLLowerIR/LowerESIMDVecArg.cpp

Lines changed: 43 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
//===----------------------------------------------------------------------===//
6666

6767
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
68+
#include "llvm/Transforms/Utils/Cloning.h"
6869

6970
using namespace llvm;
7071

@@ -73,7 +74,7 @@ using namespace llvm;
7374
namespace llvm {
7475

7576
// Forward declarations
76-
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry&);
77+
void initializeESIMDLowerVecArgLegacyPassPass(PassRegistry &);
7778
ModulePass *createESIMDLowerVecArgPass();
7879

7980
// Pass converts simd* function parameters and globals to
@@ -86,9 +87,8 @@ class ESIMDLowerVecArgPass {
8687
DenseMap<GlobalVariable *, GlobalVariable *> OldNewGlobal;
8788

8889
Function *rewriteFunc(Function &F);
89-
Type *argIsSimdPtr(Value *arg);
90+
Type *getSimdArgPtrTyOrNull(Value *arg);
9091
void fixGlobals(Module &M);
91-
bool hasGlobalConstExpr(Value *V, GlobalVariable *&Global);
9292
void replaceConstExprWithGlobals(Module &M);
9393
ConstantExpr *createNewConstantExpr(GlobalVariable *newGlobalVar,
9494
Type *oldGlobalType, Value *old);
@@ -128,14 +128,13 @@ ModulePass *llvm::createESIMDLowerVecArgPass() {
128128

129129
// Return ptr to first-class vector type if Value is a simd*, else return
130130
// nullptr.
131-
Type *ESIMDLowerVecArgPass::argIsSimdPtr(Value *arg) {
132-
auto ArgType = arg->getType();
133-
if (ArgType->isPointerTy()) {
134-
auto containedType = ArgType->getPointerElementType();
135-
if (containedType->isStructTy()) {
136-
if (containedType->getStructNumElements() == 1 &&
137-
containedType->getStructElementType(0)->isVectorTy()) {
138-
return PointerType::get(containedType->getStructElementType(0),
131+
Type *ESIMDLowerVecArgPass::getSimdArgPtrTyOrNull(Value *arg) {
132+
if (auto ArgType = dyn_cast<PointerType>(arg->getType())) {
133+
auto ContainedType = ArgType->getElementType();
134+
if (ContainedType->isStructTy()) {
135+
if (ContainedType->getStructNumElements() == 1 &&
136+
ContainedType->getStructElementType(0)->isVectorTy()) {
137+
return PointerType::get(ContainedType->getStructElementType(0),
139138
ArgType->getPointerAddressSpace());
140139
}
141140
}
@@ -150,83 +149,49 @@ Function *ESIMDLowerVecArgPass::rewriteFunc(Function &F) {
150149
FunctionType *FTy = F.getFunctionType();
151150
Type *RetTy = FTy->getReturnType();
152151
SmallVector<Type *, 8> ArgTys;
153-
AttributeList AttrVec;
154-
const AttributeList &PAL = F.getAttributes();
155-
// Argument, result of load
156-
DenseMap<Argument *, Value *> ToModify;
157-
auto &Context = F.getContext();
158152

159153
for (unsigned int i = 0; i != F.arg_size(); i++) {
160154
auto Arg = F.getArg(i);
161-
Type *NewTy = argIsSimdPtr(Arg);
155+
Type *NewTy = getSimdArgPtrTyOrNull(Arg);
162156
if (NewTy) {
163157
// Copy over byval type for simd* type
164158
ArgTys.push_back(NewTy);
165159
} else {
166160
// Transfer all non-simd ptr arguments
167161
ArgTys.push_back(Arg->getType());
168-
AttributeSet Attrs = PAL.getParamAttributes(i);
169-
if (Attrs.hasAttributes()) {
170-
AttrBuilder B(Attrs);
171-
AttrVec = AttrVec.addParamAttributes(Context, i, B);
172-
}
173162
}
174163
}
175164

176165
FunctionType *NFTy = FunctionType::get(RetTy, ArgTys, false);
177166

178-
// Add any function attributes
179-
AttributeSet FnAttrs = PAL.getFnAttributes();
180-
if (FnAttrs.hasAttributes()) {
181-
AttrBuilder B(FnAttrs);
182-
AttrVec = AttrVec.addAttributes(Context, AttributeList::FunctionIndex, B);
183-
}
184-
185-
auto RetAttrs = PAL.getRetAttributes();
186-
if (RetAttrs.hasAttributes()) {
187-
AttrBuilder B(RetAttrs);
188-
AttrVec = AttrVec.addAttributes(Context, AttributeList::ReturnIndex, B);
189-
}
190-
191167
// Create new function body and insert into the module
192168
Function *NF = Function::Create(NFTy, F.getLinkage(), F.getName());
193-
NF->copyAttributesFrom(&F);
194-
NF->setCallingConv(F.getCallingConv());
195-
196169
F.getParent()->getFunctionList().insert(F.getIterator(), NF);
197-
NF->takeName(&F);
198-
NF->setSubprogram(F.getSubprogram());
199-
200-
// Now to splice the body of the old function into the new function
201-
NF->getBasicBlockList().splice(NF->begin(), F.getBasicBlockList());
202170

171+
SmallVector<ReturnInst *, 8> Returns;
172+
SmallVector<BitCastInst *, 8> BitCasts;
173+
ValueToValueMapTy VMap;
203174
for (unsigned int I = 0; I != F.arg_size(); I++) {
204175
auto Arg = F.getArg(I);
205-
Type *newTy = argIsSimdPtr(Arg);
176+
Type *newTy = getSimdArgPtrTyOrNull(Arg);
206177
if (newTy) {
207-
// Insert bitcast
208178
// bitcast vector* -> simd*
209179
auto BitCast = new BitCastInst(NF->getArg(I), Arg->getType());
210-
NF->begin()->getInstList().push_front(BitCast);
211-
ToModify.insert(std::make_pair(Arg, nullptr));
212-
Arg->replaceAllUsesWith(BitCast);
180+
BitCasts.push_back(BitCast);
181+
VMap.insert(std::make_pair(Arg, BitCast));
182+
continue;
213183
}
184+
VMap.insert(std::make_pair(Arg, NF->getArg(I)));
214185
}
215186

216-
// Loop over the argument list, transferring uses of the old arguments to the
217-
// new arguments, also tranferring over the names as well
218-
Function::arg_iterator I2 = NF->arg_begin();
219-
unsigned int ArgNo = 0;
220-
for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E;
221-
++I, ++I2, ArgNo++) {
222-
auto ArgIt = ToModify.find(I);
223-
if (ArgIt == ToModify.end()) {
224-
// Transfer old arguments as is
225-
I->replaceAllUsesWith(I2);
226-
I2->takeName(I);
227-
}
187+
llvm::CloneFunctionInto(NF, &F, VMap, F.getSubprogram() != nullptr, Returns);
188+
189+
for (auto &B : BitCasts) {
190+
NF->begin()->getInstList().push_front(B);
228191
}
229192

193+
NF->takeName(&F);
194+
230195
// Fix call sites
231196
SmallVector<std::pair<Instruction *, Instruction *>, 10> OldNewInst;
232197
for (auto &use : F.uses()) {
@@ -235,11 +200,14 @@ Function *ESIMDLowerVecArgPass::rewriteFunc(Function &F) {
235200
auto User = use.getUser();
236201
if (isa<CallInst>(User)) {
237202
auto Call = cast<CallInst>(User);
203+
// Variadic functions not supported
204+
assert(!Call->getFunction()->isVarArg() &&
205+
"Variadic functions not supported");
238206
for (unsigned int I = 0,
239207
NumOpnds = cast<CallInst>(Call)->getNumArgOperands();
240208
I != NumOpnds; I++) {
241209
auto SrcOpnd = Call->getOperand(I);
242-
auto NewTy = argIsSimdPtr(SrcOpnd);
210+
auto NewTy = getSimdArgPtrTyOrNull(SrcOpnd);
243211
if (NewTy) {
244212
auto BitCast = new BitCastInst(SrcOpnd, NewTy, "", Call);
245213
Params.push_back(BitCast);
@@ -269,24 +237,10 @@ Function *ESIMDLowerVecArgPass::rewriteFunc(Function &F) {
269237
return NF;
270238
}
271239

272-
bool ESIMDLowerVecArgPass::hasGlobalConstExpr(Value *V, GlobalVariable *&Global) {
273-
if (isa<GlobalVariable>(V)) {
274-
Global = cast<GlobalVariable>(V);
275-
return true;
276-
}
277-
278-
if (isa<ConstantExpr>(V)) {
279-
auto FirstOpnd = cast<ConstantExpr>(V)->getOperand(0);
280-
return hasGlobalConstExpr(FirstOpnd, Global);
281-
}
282-
283-
return false;
284-
}
285-
286240
// Replace ConstantExpr if it contains old global variable.
287241
ConstantExpr *
288242
ESIMDLowerVecArgPass::createNewConstantExpr(GlobalVariable *NewGlobalVar,
289-
Type *OldGlobalType, Value *Old) {
243+
Type *OldGlobalType, Value *Old) {
290244
ConstantExpr *NewConstantExpr = nullptr;
291245

292246
if (isa<GlobalVariable>(Old)) {
@@ -308,64 +262,12 @@ ESIMDLowerVecArgPass::createNewConstantExpr(GlobalVariable *NewGlobalVar,
308262
// all such instances and replaces them with a new ConstantExpr
309263
// consisting of new global vector* variable.
310264
void ESIMDLowerVecArgPass::replaceConstExprWithGlobals(Module &M) {
311-
for (auto &F : M) {
312-
for (auto BB = F.begin(), BBEnd = F.end(); BB != BBEnd; ++BB) {
313-
DenseMap<Instruction *, Instruction *> OldNewInst;
314-
for (auto OI = BB->begin(), OE = BB->end(); OI != OE; ++OI) {
315-
SmallVector<Value *, 6> Operands;
316-
bool HasGlobals = false;
317-
auto &Inst = (*OI);
318-
for (unsigned int OP = 0, OPE = Inst.getNumOperands(); OP != OPE;
319-
++OP) {
320-
auto opnd = Inst.getOperand(OP);
321-
if (!isa<ConstantExpr>(opnd)) {
322-
Operands.push_back(opnd);
323-
continue;
324-
}
325-
GlobalVariable *OldGlobal = nullptr;
326-
auto OldGlobalVar = hasGlobalConstExpr(opnd, OldGlobal);
327-
if (OldGlobalVar && OldNewGlobal.find(OldGlobal) != OldNewGlobal.end()) {
328-
HasGlobals = true;
329-
auto NewGlobal = OldNewGlobal[OldGlobal];
330-
assert(NewGlobal && "Didnt find new global");
331-
Operands.push_back(
332-
createNewConstantExpr(NewGlobal, OldGlobal->getType(), opnd));
333-
} else {
334-
Operands.push_back(opnd);
335-
}
336-
}
337-
338-
if (HasGlobals) {
339-
Instruction *NewInst = nullptr;
340-
if (isa<CallInst>(&Inst)) {
341-
assert(isa<CallInst>(&Inst) && "Expecting call instruction");
342-
// pop last parameter which is function declaration
343-
auto CallI = cast<CallInst>(&Inst);
344-
Operands.pop_back();
345-
NewInst =
346-
CallInst::Create(CallI->getFunctionType(),
347-
CallI->getCalledFunction(), Operands, "");
348-
cast<CallInst>(NewInst)->setTailCallKind(CallI->getTailCallKind());
349-
} else if (isa<StoreInst>(&Inst)) {
350-
auto StoreI = cast<StoreInst>(&Inst);
351-
NewInst = new StoreInst(Operands[0], Operands[1],
352-
StoreI->isVolatile(), StoreI->getAlign());
353-
} else if (isa<LoadInst>(&Inst)) {
354-
auto LoadI = cast<LoadInst>(&Inst);
355-
NewInst = new LoadInst(Inst.getType(), Operands[0],
356-
LoadI->getName(), LoadI->isVolatile(),
357-
LoadI->getAlign());
358-
} else {
359-
assert(false && "Not expecting this instruction with global");
360-
}
361-
OldNewInst[&Inst] = NewInst;
362-
NewInst->copyMetadata(Inst);
363-
}
364-
}
365-
366-
for (auto Replace : OldNewInst) {
367-
ReplaceInstWithInst(Replace.first, Replace.second);
368-
}
265+
for (auto &GlobalVars : OldNewGlobal) {
266+
auto &G = *GlobalVars.first;
267+
for (auto UseOfG : G.users()) {
268+
auto NewGlobal = GlobalVars.second;
269+
auto NewConstExpr = createNewConstantExpr(NewGlobal, G.getType(), UseOfG);
270+
UseOfG->replaceAllUsesWith(NewConstExpr);
369271
}
370272
}
371273
}
@@ -374,23 +276,19 @@ void ESIMDLowerVecArgPass::replaceConstExprWithGlobals(Module &M) {
374276
// when old one is of simd* type.
375277
void ESIMDLowerVecArgPass::fixGlobals(Module &M) {
376278
for (auto &G : M.getGlobalList()) {
377-
auto NewTy = argIsSimdPtr(&G);
279+
auto NewTy = getSimdArgPtrTyOrNull(&G);
378280
if (NewTy && !G.user_empty()) {
379-
// Peel off ptr type that argIsSimdPtr applies
281+
// Peel off ptr type that getSimdArgPtrTyOrNull applies
380282
NewTy = NewTy->getPointerElementType();
381-
auto ZeroInit = new APInt(32, 0);
283+
auto ZeroInit = ConstantAggregateZero::get(NewTy);
382284
auto NewGlobalVar =
383-
new GlobalVariable(NewTy, G.isConstant(), G.getLinkage(),
384-
Constant::getIntegerValue(NewTy, *ZeroInit));
285+
new GlobalVariable(NewTy, G.isConstant(), G.getLinkage(), ZeroInit,
286+
"", G.getThreadLocalMode(), G.getAddressSpace());
385287
NewGlobalVar->setExternallyInitialized(G.isExternallyInitialized());
386288
NewGlobalVar->copyAttributesFrom(&G);
387289
NewGlobalVar->takeName(&G);
290+
NewGlobalVar->copyMetadata(&G, 0);
388291
M.getGlobalList().push_back(NewGlobalVar);
389-
SmallVector<DIGlobalVariableExpression *, 5> GVs;
390-
G.getDebugInfo(GVs);
391-
for (auto md : GVs) {
392-
NewGlobalVar->addDebugInfo(md);
393-
}
394292
OldNewGlobal.insert(std::make_pair(&G, NewGlobalVar));
395293
}
396294
}
@@ -419,7 +317,7 @@ bool ESIMDLowerVecArgPass::run(Module &M) {
419317
for (auto F : functions) {
420318
for (unsigned int I = 0; I != F->arg_size(); I++) {
421319
auto Arg = F->getArg(I);
422-
if (argIsSimdPtr(Arg)) {
320+
if (getSimdArgPtrTyOrNull(Arg)) {
423321
rewriteFunc(*F);
424322
break;
425323
}

0 commit comments

Comments
 (0)