@@ -471,17 +471,10 @@ void OCLToSPIRVBase::visitCallNDRange(CallInst *CI, StringRef DemangledName) {
471
471
void OCLToSPIRVBase::visitCallAsyncWorkGroupCopy (CallInst *CI,
472
472
StringRef DemangledName) {
473
473
assert (CI->getCalledFunction () && " Unexpected indirect call" );
474
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
475
- mutateCallInstSPIRV (
476
- M, CI,
477
- [=](CallInst *, std::vector<Value *> &Args) {
478
- if (DemangledName == OCLUtil::kOCLBuiltinName ::AsyncWorkGroupCopy) {
479
- Args.insert (Args.begin () + 3 , addSizet (1 ));
480
- }
481
- Args.insert (Args.begin (), addInt32 (ScopeWorkgroup));
482
- return getSPIRVFuncName (OpGroupAsyncCopy);
483
- },
484
- &Attrs);
474
+ auto Mutator = mutateCallInst (CI, OpGroupAsyncCopy);
475
+ if (DemangledName == OCLUtil::kOCLBuiltinName ::AsyncWorkGroupCopy)
476
+ Mutator.insertArg (3 , addSizet (1 ));
477
+ Mutator.insertArg (0 , addInt32 (ScopeWorkgroup));
485
478
}
486
479
487
480
CallInst *OCLToSPIRVBase::visitCallAtomicCmpXchg (CallInst *CI) {
@@ -514,7 +507,6 @@ void OCLToSPIRVBase::visitCallAtomicInit(CallInst *CI) {
514
507
515
508
void OCLToSPIRVBase::visitCallAllAny (spv::Op OC, CallInst *CI) {
516
509
assert (CI->getCalledFunction () && " Unexpected indirect call" );
517
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
518
510
519
511
auto Args = getArguments (CI);
520
512
assert (Args.size () == 1 );
@@ -531,19 +523,10 @@ void OCLToSPIRVBase::visitCallAllAny(spv::Op OC, CallInst *CI) {
531
523
CI->replaceAllUsesWith (Cast);
532
524
CI->eraseFromParent ();
533
525
} else {
534
- mutateCallInstSPIRV (
535
- M, CI,
536
- [&](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
537
- Args[0 ] = Cmp;
538
- Ret = Type::getInt1Ty (*Ctx);
539
-
540
- return getSPIRVFuncName (OC);
541
- },
542
- [&](CallInst *CI) -> Instruction * {
543
- return CastInst::CreateZExtOrBitCast (CI, Type::getInt32Ty (*Ctx), " " ,
544
- CI->getNextNode ());
545
- },
546
- &Attrs);
526
+ mutateCallInst (CI, OC).setArgs ({Cmp}).changeReturnType (
527
+ Type::getInt32Ty (*Ctx), [](IRBuilder<> &Builder, CallInst *CI) {
528
+ return Builder.CreateZExtOrBitCast (CI, Builder.getInt32Ty ());
529
+ });
547
530
}
548
531
}
549
532
@@ -793,13 +776,7 @@ void OCLToSPIRVBase::visitCallConvert(CallInst *CI, StringRef MangledName,
793
776
Rounding = DemangledName.substr (Loc, 4 ).str ();
794
777
}
795
778
assert (CI->getCalledFunction () && " Unexpected indirect call" );
796
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
797
- mutateCallInstSPIRV (
798
- M, CI,
799
- [=](CallInst *, std::vector<Value *> &Args) {
800
- return getSPIRVFuncName (OC, TargetTyName + Sat + Rounding);
801
- },
802
- &Attrs);
779
+ mutateCallInst (CI, getSPIRVFuncName (OC, TargetTyName + Sat + Rounding));
803
780
}
804
781
805
782
void OCLToSPIRVBase::visitCallGroupBuiltin (CallInst *CI,
@@ -971,16 +948,10 @@ void OCLToSPIRVBase::transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info) {
971
948
void OCLToSPIRVBase::visitCallReadImageMSAA (CallInst *CI,
972
949
StringRef MangledName) {
973
950
assert (MangledName.find (" msaa" ) != StringRef::npos);
974
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
975
- mutateCallInstSPIRV (
976
- M, CI,
977
- [=](CallInst *, std::vector<Value *> &Args) {
978
- Args.insert (Args.begin () + 2 , getInt32 (M, ImageOperandsSampleMask));
979
- return getSPIRVFuncName (OpImageRead,
980
- std::string (kSPIRVPostfix ::ExtDivider) +
981
- getPostfixForReturnType (CI));
982
- },
983
- &Attrs);
951
+ mutateCallInst (
952
+ CI, getSPIRVFuncName (OpImageRead, std::string (kSPIRVPostfix ::ExtDivider) +
953
+ getPostfixForReturnType (CI)))
954
+ .insertArg (2 , getInt32 (M, ImageOperandsSampleMask));
984
955
}
985
956
986
957
void OCLToSPIRVBase::visitCallReadImageWithSampler (CallInst *CI,
@@ -1200,44 +1171,15 @@ void OCLToSPIRVBase::visitCallToAddr(CallInst *CI, StringRef DemangledName) {
1200
1171
void OCLToSPIRVBase::visitCallRelational (CallInst *CI,
1201
1172
StringRef DemangledName) {
1202
1173
assert (CI->getCalledFunction () && " Unexpected indirect call" );
1203
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1204
1174
Op OC = OpNop;
1205
1175
OCLSPIRVBuiltinMap::find (DemangledName.str (), &OC);
1206
- std::string SPIRVName = getSPIRVFuncName (OC);
1207
- mutateCallInstSPIRV (
1208
- M, CI,
1209
- [=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1210
- Ret = Type::getInt1Ty (*Ctx);
1211
- if (CI->getOperand (0 )->getType ()->isVectorTy ())
1212
- Ret = FixedVectorType::get (
1213
- Type::getInt1Ty (*Ctx),
1214
- cast<FixedVectorType>(CI->getOperand (0 )->getType ())
1215
- ->getNumElements ());
1216
- return SPIRVName;
1217
- },
1218
- [=](CallInst *NewCI) -> Instruction * {
1219
- Value *False = nullptr , *True = nullptr ;
1220
- if (NewCI->getType ()->isVectorTy ()) {
1221
- Type *IntTy = Type::getInt32Ty (*Ctx);
1222
- if (cast<FixedVectorType>(NewCI->getOperand (0 )->getType ())
1223
- ->getElementType ()
1224
- ->isDoubleTy ())
1225
- IntTy = Type::getInt64Ty (*Ctx);
1226
- if (cast<FixedVectorType>(NewCI->getOperand (0 )->getType ())
1227
- ->getElementType ()
1228
- ->isHalfTy ())
1229
- IntTy = Type::getInt16Ty (*Ctx);
1230
- Type *VTy = FixedVectorType::get (
1231
- IntTy, cast<FixedVectorType>(NewCI->getType ())->getNumElements ());
1232
- False = Constant::getNullValue (VTy);
1233
- True = Constant::getAllOnesValue (VTy);
1234
- } else {
1235
- False = getInt32 (M, 0 );
1236
- True = getInt32 (M, 1 );
1237
- }
1238
- return SelectInst::Create (NewCI, True, False, " " , NewCI->getNextNode ());
1239
- },
1240
- &Attrs);
1176
+ // i1 or <i1 x N>, depending on whether it returns a vector type.
1177
+ Type *BoolTy = CI->getType ()->getWithNewType (Type::getInt1Ty (*Ctx));
1178
+ mutateCallInst (CI, OC).changeReturnType (BoolTy, [=](IRBuilder<> &Builder,
1179
+ CallInst *NewCI) {
1180
+ return Builder.CreateSelect (NewCI, Constant::getAllOnesValue (CI->getType ()),
1181
+ Constant::getNullValue (CI->getType ()));
1182
+ });
1241
1183
}
1242
1184
1243
1185
void OCLToSPIRVBase::visitCallVecLoadStore (CallInst *CI, StringRef MangledName,
@@ -1464,19 +1406,12 @@ void OCLToSPIRVBase::visitCallGetImageChannel(CallInst *CI,
1464
1406
StringRef DemangledName,
1465
1407
unsigned int Offset) {
1466
1408
assert (CI->getCalledFunction () && " Unexpected indirect call" );
1467
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1468
1409
Op OC = OpNop;
1469
1410
OCLSPIRVBuiltinMap::find (DemangledName.str (), &OC);
1470
- std::string SPIRVName = getSPIRVFuncName (OC);
1471
- mutateCallInstSPIRV (
1472
- M, CI,
1473
- [=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1474
- return SPIRVName;
1475
- },
1476
- [=](CallInst *NewCI) -> Instruction * {
1477
- return BinaryOperator::CreateAdd (NewCI, getInt32 (M, Offset), " " , CI);
1478
- },
1479
- &Attrs);
1411
+ mutateCallInst (CI, OC).changeReturnType (
1412
+ CI->getType (), [=](IRBuilder<> &Builder, CallInst *NewCI) {
1413
+ return Builder.CreateAdd (NewCI, Builder.getInt32 (Offset));
1414
+ });
1480
1415
}
1481
1416
void OCLToSPIRVBase::visitCallEnqueueKernel (CallInst *CI,
1482
1417
StringRef DemangledName) {
@@ -1636,18 +1571,12 @@ void OCLToSPIRVBase::visitSubgroupBlockWriteINTEL(CallInst *CI) {
1636
1571
1637
1572
void OCLToSPIRVBase::visitSubgroupImageMediaBlockINTEL (
1638
1573
CallInst *CI, StringRef DemangledName) {
1639
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1640
1574
spv::Op OpCode = DemangledName.rfind (" read" ) != StringRef::npos
1641
1575
? spv::OpSubgroupImageMediaBlockReadINTEL
1642
1576
: spv::OpSubgroupImageMediaBlockWriteINTEL;
1643
- mutateCallInstSPIRV (
1644
- M, CI,
1645
- [=](CallInst *, std::vector<Value *> &Args) {
1646
- // Moving the last argument to the beginning.
1647
- std::rotate (Args.begin (), Args.end () - 1 , Args.end ());
1648
- return getSPIRVFuncName (OpCode, CI->getType ());
1649
- },
1650
- &Attrs);
1577
+ // Move the last argument to the beginning.
1578
+ mutateCallInst (CI, getSPIRVFuncName (OpCode, CI->getType ()))
1579
+ .moveArg (CI->arg_size () - 1 , 0 );
1651
1580
}
1652
1581
1653
1582
static const char *getSubgroupAVCIntelOpKind (StringRef Name) {
@@ -1710,13 +1639,7 @@ void OCLToSPIRVBase::visitSubgroupAVCBuiltinCall(CallInst *CI,
1710
1639
return ;
1711
1640
}
1712
1641
1713
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1714
- mutateCallInstSPIRV (
1715
- M, CI,
1716
- [=](CallInst *, std::vector<Value *> &Args) {
1717
- return getSPIRVFuncName (OC);
1718
- },
1719
- &Attrs);
1642
+ mutateCallInst (CI, OC);
1720
1643
}
1721
1644
1722
1645
// Handles Subgroup AVC Intel extension wrapper built-ins.
@@ -1939,13 +1862,7 @@ void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
1939
1862
}
1940
1863
}
1941
1864
1942
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1943
- mutateCallInstSPIRV (
1944
- M, CI,
1945
- [=](CallInst *, std::vector<Value *> &Args) {
1946
- return getSPIRVFuncName (internal::OpConvertFToBF16INTEL);
1947
- },
1948
- &Attrs);
1865
+ mutateCallInst (CI, internal::OpConvertFToBF16INTEL);
1949
1866
}
1950
1867
1951
1868
void OCLToSPIRVBase::visitCallConvertAsBFloat16Float (CallInst *CI,
@@ -1988,13 +1905,7 @@ void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
1988
1905
}
1989
1906
}
1990
1907
1991
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1992
- mutateCallInstSPIRV (
1993
- M, CI,
1994
- [=](CallInst *, std::vector<Value *> &Args) {
1995
- return getSPIRVFuncName (internal::OpConvertBF16ToFINTEL);
1996
- },
1997
- &Attrs);
1908
+ mutateCallInst (CI, internal::OpConvertBF16ToFINTEL);
1998
1909
}
1999
1910
} // namespace SPIRV
2000
1911
0 commit comments