
Commit 1aa3b17

jcranmer-intel authored and svenvh committed
[Mutator] Migrate OCLBuiltinTransInfo::PostProc to the mutator interface.
This requires changing all of the callers of this method at the same time.
1 parent 4ac85bc commit 1aa3b17
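In short, the PostProc callback no longer edits the raw operand vector of the call being rewritten; it drives a BuiltinCallMutator instead (see the OCLUtil.h hunk at the bottom for the signature change). The following standalone sketch contrasts the two callback shapes; MockValue and MockMutator are stand-ins invented for this illustration and are not part of the translator.

// Standalone sketch (not LLVM code): contrast of the old and new PostProc
// callback shapes. MockValue and MockMutator are invented for illustration.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

using MockValue = std::string;

struct MockMutator {
  std::vector<MockValue> Args;
  MockMutator &appendArg(MockValue V) {
    Args.push_back(std::move(V));
    return *this;
  }
  size_t arg_size() const { return Args.size(); }
};

int main() {
  // Old interface: the callback mutates the operand vector directly.
  std::function<void(std::vector<MockValue> &)> OldPostProc =
      [](std::vector<MockValue> &Ops) { Ops.push_back("memory_order"); };

  // New interface: the callback goes through the mutator object.
  std::function<void(MockMutator &)> NewPostProc =
      [](MockMutator &Mutator) { Mutator.appendArg("memory_order"); };

  std::vector<MockValue> Ops{"object"};
  OldPostProc(Ops);

  MockMutator Mutator{{"object"}};
  NewPostProc(Mutator);

  std::cout << Ops.size() << " vs " << Mutator.arg_size() << "\n"; // 2 vs 2
}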

File tree

3 files changed: +116, -133 lines

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 110 additions & 131 deletions
@@ -593,9 +593,9 @@ void OCLToSPIRVBase::visitCallAtomicLegacy(CallInst *CI, StringRef MangledName,
   PostOps.push_back(OCLLegacyAtomicMemOrder);
   PostOps.push_back(OCLLegacyAtomicMemScope);
 
-  Info.PostProc = [=](std::vector<Value *> &Ops) {
+  Info.PostProc = [=](BuiltinCallMutator &Mutator) {
     for (auto &I : PostOps) {
-      Ops.push_back(addInt32(I));
+      Mutator.appendArg(addInt32(I));
     }
   };
   transAtomicBuiltin(CI, Info);
@@ -637,9 +637,9 @@ void OCLToSPIRVBase::visitCallAtomicCpp11(CallInst *CI, StringRef MangledName,
 
   OCLBuiltinTransInfo Info;
   Info.UniqName = std::string("atomic_") + NewStem;
-  Info.PostProc = [=](std::vector<Value *> &Ops) {
+  Info.PostProc = [=](BuiltinCallMutator &Mutator) {
     for (auto &I : PostOps) {
-      Ops.push_back(addInt32(I));
+      Mutator.appendArg(addInt32(I));
     }
   };
 
@@ -648,72 +648,65 @@ void OCLToSPIRVBase::visitCallAtomicCpp11(CallInst *CI, StringRef MangledName,
 
 void OCLToSPIRVBase::transAtomicBuiltin(CallInst *CI,
                                         OCLBuiltinTransInfo &Info) {
-  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
-  mutateCallInstSPIRV(
-      M, CI,
-      [=](CallInst *CI, std::vector<Value *> &Args) -> std::string {
-        Info.PostProc(Args);
-        // Order of args in OCL20:
-        // object, 0-2 other args, 1-2 order, scope
-        const size_t NumOrder =
-            getAtomicBuiltinNumMemoryOrderArgs(Info.UniqName);
-        const size_t ArgsCount = Args.size();
-        const size_t ScopeIdx = ArgsCount - 1;
-        const size_t OrderIdx = ScopeIdx - NumOrder;
-
-        Args[ScopeIdx] =
-            transOCLMemScopeIntoSPIRVScope(Args[ScopeIdx], OCLMS_device, CI);
-
-        for (size_t I = 0; I < NumOrder; ++I) {
-          Args[OrderIdx + I] = transOCLMemOrderIntoSPIRVMemorySemantics(
-              Args[OrderIdx + I], OCLMO_seq_cst, CI);
-        }
-        // Order of args in SPIR-V:
-        // object, scope, 1-2 order, 0-2 other args
-        std::swap(Args[1], Args[ScopeIdx]);
-        if (OrderIdx > 2) {
-          // For atomic_compare_exchange the swap above puts Comparator/Expected
-          // argument just where it should be, so don't move the last argument
-          // then.
-          int Offset =
-              Info.UniqName.find("atomic_compare_exchange") == 0 ? 1 : 0;
-          std::rotate(Args.begin() + 2, Args.begin() + OrderIdx,
-                      Args.end() - Offset);
-        }
-        llvm::Type *AtomicBuiltinsReturnType =
-            CI->getCalledFunction()->getReturnType();
-        auto IsFPType = [](llvm::Type *ReturnType) {
-          return ReturnType->isHalfTy() || ReturnType->isFloatTy() ||
-                 ReturnType->isDoubleTy();
-        };
-        auto SPIRVFunctionName =
-            getSPIRVFuncName(OCLSPIRVBuiltinMap::map(Info.UniqName));
-        if (!IsFPType(AtomicBuiltinsReturnType))
-          return SPIRVFunctionName;
-        // Translate FP-typed atomic builtins. Currently we only need to
-        // translate atomic_fetch_[add, sub, max, min] and atomic_fetch_[add,
-        // sub, max, min]_explicit to related float instructions.
-        // Translate atomic_fetch_sub to OpAtomicFAddEXT with negative value
-        // operand
-        auto SPIRFunctionNameForFloatAtomics =
-            llvm::StringSwitch<std::string>(SPIRVFunctionName)
-                .Case("__spirv_AtomicIAdd", "__spirv_AtomicFAddEXT")
-                .Case("__spirv_AtomicISub", "__spirv_AtomicFAddEXT")
-                .Case("__spirv_AtomicSMax", "__spirv_AtomicFMaxEXT")
-                .Case("__spirv_AtomicSMin", "__spirv_AtomicFMinEXT")
-                .Default("others");
-        if (SPIRVFunctionName == "__spirv_AtomicISub") {
-          IRBuilder<> IRB(CI);
-          // Set float operand to its negation
-          CI->setOperand(1, IRB.CreateFNeg(CI->getArgOperand(1)));
-          // Update Args which is used to generate new call
-          Args.back() = CI->getArgOperand(1);
-        }
-        return SPIRFunctionNameForFloatAtomics == "others"
-                   ? SPIRVFunctionName
-                   : SPIRFunctionNameForFloatAtomics;
-      },
-      &Attrs);
+  llvm::Type *AtomicBuiltinsReturnType = CI->getType();
+  auto SPIRVFunctionName =
+      getSPIRVFuncName(OCLSPIRVBuiltinMap::map(Info.UniqName));
+  bool NeedsNegate = false;
+  if (AtomicBuiltinsReturnType->isFloatingPointTy()) {
+    // Translate FP-typed atomic builtins. Currently we only need to
+    // translate atomic_fetch_[add, sub, max, min] and atomic_fetch_[add,
+    // sub, max, min]_explicit to related float instructions.
+    // Translate atomic_fetch_sub to OpAtomicFAddEXT with negative value
+    // operand
+    auto SPIRFunctionNameForFloatAtomics =
+        llvm::StringSwitch<std::string>(SPIRVFunctionName)
+            .Case("__spirv_AtomicIAdd", "__spirv_AtomicFAddEXT")
+            .Case("__spirv_AtomicISub", "__spirv_AtomicFAddEXT")
+            .Case("__spirv_AtomicSMax", "__spirv_AtomicFMaxEXT")
+            .Case("__spirv_AtomicSMin", "__spirv_AtomicFMinEXT")
+            .Default("others");
+    if (SPIRVFunctionName == "__spirv_AtomicISub") {
+      NeedsNegate = true;
+    }
+    if (SPIRFunctionNameForFloatAtomics != "others")
+      SPIRVFunctionName = SPIRFunctionNameForFloatAtomics;
+  }
+
+  auto Mutator = mutateCallInst(CI, SPIRVFunctionName);
+  Info.PostProc(Mutator);
+  // Order of args in OCL20:
+  // object, 0-2 other args, 1-2 order, scope
+  const size_t NumOrder = getAtomicBuiltinNumMemoryOrderArgs(Info.UniqName);
+  const size_t ArgsCount = Mutator.arg_size();
+  const size_t ScopeIdx = ArgsCount - 1;
+  const size_t OrderIdx = ScopeIdx - NumOrder;
+
+  if (NeedsNegate) {
+    Mutator.mapArg(1, [=](Value *V) {
+      IRBuilder<> IRB(CI);
+      return IRB.CreateFNeg(V);
+    });
+  }
+  Mutator.mapArg(ScopeIdx, [=](Value *V) {
+    return transOCLMemScopeIntoSPIRVScope(V, OCLMS_device, CI);
+  });
+  for (size_t I = 0; I < NumOrder; ++I) {
+    Mutator.mapArg(OrderIdx + I, [=](Value *V) {
+      return transOCLMemOrderIntoSPIRVMemorySemantics(V, OCLMO_seq_cst, CI);
+    });
+  }
+
+  // Order of args in SPIR-V:
+  // object, scope, 1-2 order, 0-2 other args
+  for (size_t I = 0; I < NumOrder; ++I) {
+    Mutator.moveArg(OrderIdx + I, I + 1);
+  }
+  Mutator.moveArg(ScopeIdx, 1);
+  if (Info.UniqName.find("atomic_compare_exchange") == 0) {
+    // For atomic_compare_exchange, the two "other args" are in the opposite
+    // order from the SPIR-V order. Swap these two arguments.
+    Mutator.moveArg(Mutator.arg_size() - 1, Mutator.arg_size() - 2);
+  }
 }
 
 void OCLToSPIRVBase::visitCallBarrier(CallInst *CI) {
@@ -871,24 +864,29 @@ void OCLToSPIRVBase::visitCallGroupBuiltin(CallInst *CI,
   if (HasBoolReturnType)
     Info.RetTy = Type::getInt1Ty(*Ctx);
   Info.UniqName = DemangledName;
-  Info.PostProc = [=](std::vector<Value *> &Ops) {
+  Info.PostProc = [=](BuiltinCallMutator &Mutator) {
     if (HasBoolArg) {
-      IRBuilder<> IRB(CI);
-      Ops[0] =
-          IRB.CreateICmpNE(Ops[0], ConstantInt::get(Type::getInt32Ty(*Ctx), 0));
+      Mutator.mapArg(0, [&](Value *V) {
+        IRBuilder<> IRB(CI);
+        return IRB.CreateICmpNE(V, IRB.getInt32(0));
+      });
     }
-    size_t E = Ops.size();
+    size_t E = Mutator.arg_size();
     if (DemangledName == "group_broadcast" && E > 2) {
       assert(E == 3 || E == 4);
+      std::vector<Value *> Ops = getArguments(CI);
       makeVector(CI, Ops, std::make_pair(Ops.begin() + 1, Ops.end()));
+      while (Mutator.arg_size() > 1)
+        Mutator.removeArg(1);
+      Mutator.appendArg(Ops.back());
     }
-    Ops.insert(Ops.begin(), Consts.begin(), Consts.end());
+    for (unsigned I = 0; I < Consts.size(); I++)
+      Mutator.insertArg(I, Consts[I]);
   };
   transBuiltin(CI, Info);
 }
 
 void OCLToSPIRVBase::transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info) {
-  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
   Op OC = OpNop;
   unsigned ExtOp = ~0U;
   SPIRVBuiltinVariableKind BVKind = BuiltInMax;
@@ -918,31 +916,18 @@ void OCLToSPIRVBase::transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info) {
     Info.UniqName = getSPIRVFuncName(BVKind);
   } else
     return;
-  if (!Info.RetTy)
-    mutateCallInstSPIRV(
-        M, CI,
-        [=](CallInst *, std::vector<Value *> &Args) {
-          Info.PostProc(Args);
-          return Info.UniqName + Info.Postfix;
-        },
-        &Attrs);
-  else
-    mutateCallInstSPIRV(
-        M, CI,
-        [=](CallInst *, std::vector<Value *> &Args, Type *&RetTy) {
-          Info.PostProc(Args);
-          RetTy = Info.RetTy;
-          return Info.UniqName + Info.Postfix;
-        },
-        [=](CallInst *NewCI) -> Instruction * {
-          if (NewCI->getType()->isIntegerTy() && CI->getType()->isIntegerTy())
-            return CastInst::CreateIntegerCast(NewCI, CI->getType(),
-                                               Info.IsRetSigned, "", CI);
+  auto Mutator = mutateCallInst(CI, Info.UniqName + Info.Postfix);
+  Info.PostProc(Mutator);
+  if (Info.RetTy) {
+    Type *OldRetTy = CI->getType();
+    Mutator.changeReturnType(
+        Info.RetTy, [&](IRBuilder<> &Builder, CallInst *NewCI) {
+          if (Info.RetTy->isIntegerTy() && OldRetTy->isIntegerTy())
+            return Builder.CreateIntCast(NewCI, OldRetTy, Info.IsRetSigned);
           else
-            return CastInst::CreatePointerBitCastOrAddrSpaceCast(
-                NewCI, CI->getType(), "", CI);
-        },
-        &Attrs);
+            return Builder.CreatePointerBitCastOrAddrSpaceCast(NewCI, OldRetTy);
+        });
+  }
 }
 
 void OCLToSPIRVBase::visitCallReadImageMSAA(CallInst *CI,
@@ -1122,27 +1107,25 @@ void OCLToSPIRVBase::visitCallReadWriteImage(CallInst *CI,
     Info.UniqName = kOCLBuiltinName::ReadImage;
     unsigned ImgOpMask = getImageSignZeroExt(DemangledName);
     if (ImgOpMask) {
-      Info.PostProc = [&](std::vector<Value *> &Args) {
-        Args.push_back(getInt32(M, ImgOpMask));
+      Info.PostProc = [&](BuiltinCallMutator &Mutator) {
+        Mutator.appendArg(getInt32(M, ImgOpMask));
       };
     }
   }
 
   if (DemangledName.find(kOCLBuiltinName::WriteImage) == 0) {
     Info.UniqName = kOCLBuiltinName::WriteImage;
-    Info.PostProc = [&](std::vector<Value *> &Args) {
+    Info.PostProc = [&](BuiltinCallMutator &Mutator) {
       unsigned ImgOpMask = getImageSignZeroExt(DemangledName);
-      unsigned ImgOpMaskInsIndex = Args.size();
-      if (Args.size() == 4) // write with lod
+      unsigned ImgOpMaskInsIndex = Mutator.arg_size();
+      if (Mutator.arg_size() == 4) // write with lod
       {
-        auto Lod = Args[2];
-        Args.erase(Args.begin() + 2);
         ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
-        ImgOpMaskInsIndex = Args.size();
-        Args.push_back(Lod);
+        ImgOpMaskInsIndex = Mutator.arg_size() - 1;
+        Mutator.moveArg(2, Mutator.arg_size() - 1);
      }
       if (ImgOpMask) {
-        Args.insert(Args.begin() + ImgOpMaskInsIndex, getInt32(M, ImgOpMask));
+        Mutator.insertArg(ImgOpMaskInsIndex, getInt32(M, ImgOpMask));
       }
     };
   }
@@ -1159,11 +1142,14 @@ void OCLToSPIRVBase::visitCallToAddr(CallInst *CI, StringRef DemangledName) {
       SPIRAddrSpaceCapitalizedNameMap::map(AddrSpace);
   auto StorageClass = addInt32(SPIRSPIRVAddrSpaceMap::map(AddrSpace));
   Info.RetTy = getInt8PtrTy(cast<PointerType>(CI->getType()));
-  Info.PostProc = [=](std::vector<Value *> &Ops) {
-    auto P = Ops.back();
-    Ops.pop_back();
-    Ops.push_back(castToInt8Ptr(P, CI));
-    Ops.push_back(StorageClass);
+  Info.PostProc = [=](BuiltinCallMutator &Mutator) {
+    Mutator
+        .mapArg(Mutator.arg_size() - 1,
+                [&](Value *V) {
+                  return std::pair<Value *, Type *>(
+                      castToInt8Ptr(V, CI), Type::getInt8Ty(V->getContext()));
+                })
+        .appendArg(StorageClass);
   };
   transBuiltin(CI, Info);
 }
@@ -1216,8 +1202,9 @@ void OCLToSPIRVBase::visitCallVecLoadStore(CallInst *CI, StringRef MangledName,
   if (DemangledName.find(kOCLBuiltinName::VLoadPrefix) == 0)
     Info.Postfix =
        std::string(kSPIRVPostfix::ExtDivider) + getPostfixForReturnType(CI);
-  Info.PostProc = [=](std::vector<Value *> &Ops) {
-    Ops.insert(Ops.end(), Consts.begin(), Consts.end());
+  Info.PostProc = [=](BuiltinCallMutator &Mutator) {
+    for (auto *Value : Consts)
+      Mutator.appendArg(Value);
   };
   transBuiltin(CI, Info);
 }
@@ -1514,9 +1501,8 @@ void OCLToSPIRVBase::visitCallKernelQuery(CallInst *CI,
 
 // Add postfix to overloaded intel subgroup block read/write builtins
 // so new functions can be distinguished.
-static void processSubgroupBlockReadWriteINTEL(CallInst *CI,
-                                               OCLBuiltinTransInfo &Info,
-                                               const Type *DataTy, Module *M) {
+void OCLToSPIRVBase::processSubgroupBlockReadWriteINTEL(
+    CallInst *CI, OCLBuiltinTransInfo &Info, const Type *DataTy) {
   unsigned VectorNumElements = 1;
   if (auto *VecTy = dyn_cast<FixedVectorType>(DataTy))
     VectorNumElements = VecTy->getNumElements();
@@ -1525,14 +1511,7 @@ static void processSubgroupBlockReadWriteINTEL(CallInst *CI,
   Info.Postfix +=
       getIntelSubgroupBlockDataPostfix(ElementBitSize, VectorNumElements);
   assert(CI->getCalledFunction() && "Unexpected indirect call");
-  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
-  mutateCallInstSPIRV(
-      M, CI,
-      [&Info](CallInst *, std::vector<Value *> &Args) {
-        Info.PostProc(Args);
-        return Info.UniqName + Info.Postfix;
-      },
-      &Attrs);
+  mutateCallInst(CI, Info.UniqName + Info.Postfix);
 }
 
 // The intel_sub_group_block_read built-ins are overloaded to support both
@@ -1548,7 +1527,7 @@ void OCLToSPIRVBase::visitSubgroupBlockReadINTEL(CallInst *CI) {
   else
     Info.UniqName = getSPIRVFuncName(spv::OpSubgroupBlockReadINTEL);
   Type *DataTy = CI->getType();
-  processSubgroupBlockReadWriteINTEL(CI, Info, DataTy, M);
+  processSubgroupBlockReadWriteINTEL(CI, Info, DataTy);
 }
 
 // The intel_sub_group_block_write built-ins are similarly overloaded to support
@@ -1566,7 +1545,7 @@ void OCLToSPIRVBase::visitSubgroupBlockWriteINTEL(CallInst *CI) {
          "Intel subgroup block write should have arguments");
   unsigned DataArg = CI->arg_size() - 1;
   Type *DataTy = CI->getArgOperand(DataArg)->getType();
-  processSubgroupBlockReadWriteINTEL(CI, Info, DataTy, M);
+  processSubgroupBlockReadWriteINTEL(CI, Info, DataTy);
 }
 
 void OCLToSPIRVBase::visitSubgroupImageMediaBlockINTEL(
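As an aid to following the argument shuffling that the reworked transAtomicBuiltin above performs through the mutator, here is a standalone simulation for a call shaped like atomic_fetch_add_explicit(object, operand, order, scope). It assumes moveArg(From, To) behaves as "remove the argument at index From and reinsert it at index To", which is how the move loop in the hunk reads; this is an illustrative trace, not code from the patch.

// Illustrative trace of the OCL 2.0 -> SPIR-V argument reordering done in
// transAtomicBuiltin, simulated on plain strings. The moveArg semantics below
// ("erase at From, insert at To") is an assumption made for this sketch.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

static void moveArg(std::vector<std::string> &Args, size_t From, size_t To) {
  std::string V = Args[From];
  Args.erase(Args.begin() + From);
  Args.insert(Args.begin() + To, V);
}

int main() {
  // OCL 2.0 order: object, 0-2 other args, 1-2 order, scope.
  std::vector<std::string> Args = {"object", "operand", "order", "scope"};
  const size_t NumOrder = 1;                   // one memory-order argument
  const size_t ScopeIdx = Args.size() - 1;     // 3
  const size_t OrderIdx = ScopeIdx - NumOrder; // 2

  // Same move sequence as the mutator-based code in the hunk above.
  for (size_t I = 0; I < NumOrder; ++I)
    moveArg(Args, OrderIdx + I, I + 1);
  moveArg(Args, ScopeIdx, 1);

  // SPIR-V order: object, scope, 1-2 order, 0-2 other args.
  for (const std::string &A : Args)
    std::cout << A << ' ';
  std::cout << '\n'; // prints: object scope order operand
}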

lib/SPIRV/OCLToSPIRV.h

Lines changed: 4 additions & 0 deletions
@@ -280,6 +280,10 @@ class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase>, BuiltinCallHelper {
   /// Transform OpenCL vload/vstore function name.
   void transVecLoadStoreName(std::string &DemangledName,
                              const std::string &Stem, bool AlwaysN);
+
+  void processSubgroupBlockReadWriteINTEL(CallInst *CI,
+                                          OCLBuiltinTransInfo &Info,
+                                          const Type *DataTy);
 };
 
 class OCLToSPIRVLegacy : public OCLToSPIRVBase, public llvm::ModulePass {

lib/SPIRV/OCLUtil.h

Lines changed: 2 additions & 2 deletions
@@ -158,12 +158,12 @@ struct OCLBuiltinTransInfo {
   std::string MangledName;
   std::string Postfix; // Postfix to be added
   /// Postprocessor of operands
-  std::function<void(std::vector<Value *> &)> PostProc;
+  std::function<void(BuiltinCallMutator &)> PostProc;
   Type *RetTy;      // Return type of the translated function
   bool IsRetSigned; // When RetTy is int, determines if extensions
                     // on it should be a sext or zet.
   OCLBuiltinTransInfo() : RetTy(nullptr), IsRetSigned(false) {
-    PostProc = [](std::vector<Value *> &) {};
+    PostProc = [](BuiltinCallMutator &) {};
   }
 };
