Skip to content

Commit 05b41a4

Browse files
jcranmer-intelbader
authored andcommitted
[Mutator] Migrate miscellaneous but simple calls to the mutator interface.
1 parent 6fdd2d0 commit 05b41a4

File tree

2 files changed

+84
-255
lines changed

2 files changed

+84
-255
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 30 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -471,17 +471,10 @@ void OCLToSPIRVBase::visitCallNDRange(CallInst *CI, StringRef DemangledName) {
471471
void OCLToSPIRVBase::visitCallAsyncWorkGroupCopy(CallInst *CI,
472472
StringRef DemangledName) {
473473
assert(CI->getCalledFunction() && "Unexpected indirect call");
474-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
475-
mutateCallInstSPIRV(
476-
M, CI,
477-
[=](CallInst *, std::vector<Value *> &Args) {
478-
if (DemangledName == OCLUtil::kOCLBuiltinName::AsyncWorkGroupCopy) {
479-
Args.insert(Args.begin() + 3, addSizet(1));
480-
}
481-
Args.insert(Args.begin(), addInt32(ScopeWorkgroup));
482-
return getSPIRVFuncName(OpGroupAsyncCopy);
483-
},
484-
&Attrs);
474+
auto Mutator = mutateCallInst(CI, OpGroupAsyncCopy);
475+
if (DemangledName == OCLUtil::kOCLBuiltinName::AsyncWorkGroupCopy)
476+
Mutator.insertArg(3, addSizet(1));
477+
Mutator.insertArg(0, addInt32(ScopeWorkgroup));
485478
}
486479

487480
CallInst *OCLToSPIRVBase::visitCallAtomicCmpXchg(CallInst *CI) {
@@ -514,7 +507,6 @@ void OCLToSPIRVBase::visitCallAtomicInit(CallInst *CI) {
514507

515508
void OCLToSPIRVBase::visitCallAllAny(spv::Op OC, CallInst *CI) {
516509
assert(CI->getCalledFunction() && "Unexpected indirect call");
517-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
518510

519511
auto Args = getArguments(CI);
520512
assert(Args.size() == 1);
@@ -531,19 +523,10 @@ void OCLToSPIRVBase::visitCallAllAny(spv::Op OC, CallInst *CI) {
531523
CI->replaceAllUsesWith(Cast);
532524
CI->eraseFromParent();
533525
} else {
534-
mutateCallInstSPIRV(
535-
M, CI,
536-
[&](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
537-
Args[0] = Cmp;
538-
Ret = Type::getInt1Ty(*Ctx);
539-
540-
return getSPIRVFuncName(OC);
541-
},
542-
[&](CallInst *CI) -> Instruction * {
543-
return CastInst::CreateZExtOrBitCast(CI, Type::getInt32Ty(*Ctx), "",
544-
CI->getNextNode());
545-
},
546-
&Attrs);
526+
mutateCallInst(CI, OC).setArgs({Cmp}).changeReturnType(
527+
Type::getInt32Ty(*Ctx), [](IRBuilder<> &Builder, CallInst *CI) {
528+
return Builder.CreateZExtOrBitCast(CI, Builder.getInt32Ty());
529+
});
547530
}
548531
}
549532

@@ -793,13 +776,7 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
793776
Rounding = DemangledName.substr(Loc, 4).str();
794777
}
795778
assert(CI->getCalledFunction() && "Unexpected indirect call");
796-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
797-
mutateCallInstSPIRV(
798-
M, CI,
799-
[=](CallInst *, std::vector<Value *> &Args) {
800-
return getSPIRVFuncName(OC, TargetTyName + Sat + Rounding);
801-
},
802-
&Attrs);
779+
mutateCallInst(CI, getSPIRVFuncName(OC, TargetTyName + Sat + Rounding));
803780
}
804781

805782
void OCLToSPIRVBase::visitCallGroupBuiltin(CallInst *CI,
@@ -971,16 +948,10 @@ void OCLToSPIRVBase::transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info) {
971948
void OCLToSPIRVBase::visitCallReadImageMSAA(CallInst *CI,
972949
StringRef MangledName) {
973950
assert(MangledName.find("msaa") != StringRef::npos);
974-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
975-
mutateCallInstSPIRV(
976-
M, CI,
977-
[=](CallInst *, std::vector<Value *> &Args) {
978-
Args.insert(Args.begin() + 2, getInt32(M, ImageOperandsSampleMask));
979-
return getSPIRVFuncName(OpImageRead,
980-
std::string(kSPIRVPostfix::ExtDivider) +
981-
getPostfixForReturnType(CI));
982-
},
983-
&Attrs);
951+
mutateCallInst(
952+
CI, getSPIRVFuncName(OpImageRead, std::string(kSPIRVPostfix::ExtDivider) +
953+
getPostfixForReturnType(CI)))
954+
.insertArg(2, getInt32(M, ImageOperandsSampleMask));
984955
}
985956

986957
void OCLToSPIRVBase::visitCallReadImageWithSampler(CallInst *CI,
@@ -1200,44 +1171,15 @@ void OCLToSPIRVBase::visitCallToAddr(CallInst *CI, StringRef DemangledName) {
12001171
void OCLToSPIRVBase::visitCallRelational(CallInst *CI,
12011172
StringRef DemangledName) {
12021173
assert(CI->getCalledFunction() && "Unexpected indirect call");
1203-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
12041174
Op OC = OpNop;
12051175
OCLSPIRVBuiltinMap::find(DemangledName.str(), &OC);
1206-
std::string SPIRVName = getSPIRVFuncName(OC);
1207-
mutateCallInstSPIRV(
1208-
M, CI,
1209-
[=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1210-
Ret = Type::getInt1Ty(*Ctx);
1211-
if (CI->getOperand(0)->getType()->isVectorTy())
1212-
Ret = FixedVectorType::get(
1213-
Type::getInt1Ty(*Ctx),
1214-
cast<FixedVectorType>(CI->getOperand(0)->getType())
1215-
->getNumElements());
1216-
return SPIRVName;
1217-
},
1218-
[=](CallInst *NewCI) -> Instruction * {
1219-
Value *False = nullptr, *True = nullptr;
1220-
if (NewCI->getType()->isVectorTy()) {
1221-
Type *IntTy = Type::getInt32Ty(*Ctx);
1222-
if (cast<FixedVectorType>(NewCI->getOperand(0)->getType())
1223-
->getElementType()
1224-
->isDoubleTy())
1225-
IntTy = Type::getInt64Ty(*Ctx);
1226-
if (cast<FixedVectorType>(NewCI->getOperand(0)->getType())
1227-
->getElementType()
1228-
->isHalfTy())
1229-
IntTy = Type::getInt16Ty(*Ctx);
1230-
Type *VTy = FixedVectorType::get(
1231-
IntTy, cast<FixedVectorType>(NewCI->getType())->getNumElements());
1232-
False = Constant::getNullValue(VTy);
1233-
True = Constant::getAllOnesValue(VTy);
1234-
} else {
1235-
False = getInt32(M, 0);
1236-
True = getInt32(M, 1);
1237-
}
1238-
return SelectInst::Create(NewCI, True, False, "", NewCI->getNextNode());
1239-
},
1240-
&Attrs);
1176+
// i1 or <i1 x N>, depending on whether it returns a vector type.
1177+
Type *BoolTy = CI->getType()->getWithNewType(Type::getInt1Ty(*Ctx));
1178+
mutateCallInst(CI, OC).changeReturnType(BoolTy, [=](IRBuilder<> &Builder,
1179+
CallInst *NewCI) {
1180+
return Builder.CreateSelect(NewCI, Constant::getAllOnesValue(CI->getType()),
1181+
Constant::getNullValue(CI->getType()));
1182+
});
12411183
}
12421184

12431185
void OCLToSPIRVBase::visitCallVecLoadStore(CallInst *CI, StringRef MangledName,
@@ -1464,19 +1406,12 @@ void OCLToSPIRVBase::visitCallGetImageChannel(CallInst *CI,
14641406
StringRef DemangledName,
14651407
unsigned int Offset) {
14661408
assert(CI->getCalledFunction() && "Unexpected indirect call");
1467-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
14681409
Op OC = OpNop;
14691410
OCLSPIRVBuiltinMap::find(DemangledName.str(), &OC);
1470-
std::string SPIRVName = getSPIRVFuncName(OC);
1471-
mutateCallInstSPIRV(
1472-
M, CI,
1473-
[=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1474-
return SPIRVName;
1475-
},
1476-
[=](CallInst *NewCI) -> Instruction * {
1477-
return BinaryOperator::CreateAdd(NewCI, getInt32(M, Offset), "", CI);
1478-
},
1479-
&Attrs);
1411+
mutateCallInst(CI, OC).changeReturnType(
1412+
CI->getType(), [=](IRBuilder<> &Builder, CallInst *NewCI) {
1413+
return Builder.CreateAdd(NewCI, Builder.getInt32(Offset));
1414+
});
14801415
}
14811416
void OCLToSPIRVBase::visitCallEnqueueKernel(CallInst *CI,
14821417
StringRef DemangledName) {
@@ -1636,18 +1571,12 @@ void OCLToSPIRVBase::visitSubgroupBlockWriteINTEL(CallInst *CI) {
16361571

16371572
void OCLToSPIRVBase::visitSubgroupImageMediaBlockINTEL(
16381573
CallInst *CI, StringRef DemangledName) {
1639-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
16401574
spv::Op OpCode = DemangledName.rfind("read") != StringRef::npos
16411575
? spv::OpSubgroupImageMediaBlockReadINTEL
16421576
: spv::OpSubgroupImageMediaBlockWriteINTEL;
1643-
mutateCallInstSPIRV(
1644-
M, CI,
1645-
[=](CallInst *, std::vector<Value *> &Args) {
1646-
// Moving the last argument to the beginning.
1647-
std::rotate(Args.begin(), Args.end() - 1, Args.end());
1648-
return getSPIRVFuncName(OpCode, CI->getType());
1649-
},
1650-
&Attrs);
1577+
// Move the last argument to the beginning.
1578+
mutateCallInst(CI, getSPIRVFuncName(OpCode, CI->getType()))
1579+
.moveArg(CI->arg_size() - 1, 0);
16511580
}
16521581

16531582
static const char *getSubgroupAVCIntelOpKind(StringRef Name) {
@@ -1710,13 +1639,7 @@ void OCLToSPIRVBase::visitSubgroupAVCBuiltinCall(CallInst *CI,
17101639
return;
17111640
}
17121641

1713-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1714-
mutateCallInstSPIRV(
1715-
M, CI,
1716-
[=](CallInst *, std::vector<Value *> &Args) {
1717-
return getSPIRVFuncName(OC);
1718-
},
1719-
&Attrs);
1642+
mutateCallInst(CI, OC);
17201643
}
17211644

17221645
// Handles Subgroup AVC Intel extension wrapper built-ins.
@@ -1939,13 +1862,7 @@ void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
19391862
}
19401863
}
19411864

1942-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1943-
mutateCallInstSPIRV(
1944-
M, CI,
1945-
[=](CallInst *, std::vector<Value *> &Args) {
1946-
return getSPIRVFuncName(internal::OpConvertFToBF16INTEL);
1947-
},
1948-
&Attrs);
1865+
mutateCallInst(CI, internal::OpConvertFToBF16INTEL);
19491866
}
19501867

19511868
void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
@@ -1988,13 +1905,7 @@ void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
19881905
}
19891906
}
19901907

1991-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1992-
mutateCallInstSPIRV(
1993-
M, CI,
1994-
[=](CallInst *, std::vector<Value *> &Args) {
1995-
return getSPIRVFuncName(internal::OpConvertBF16ToFINTEL);
1996-
},
1997-
&Attrs);
1908+
mutateCallInst(CI, internal::OpConvertBF16ToFINTEL);
19981909
}
19991910
} // namespace SPIRV
20001911

0 commit comments

Comments
 (0)