Skip to content

Commit dfc3e2d

Browse files
[Mutator] Migrate handling of the AVC builtins to the mutator interface. (#1623)
1 parent a33e094 commit dfc3e2d

File tree

2 files changed

+120
-137
lines changed

2 files changed

+120
-137
lines changed

lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 57 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1595,7 +1595,6 @@ void OCLToSPIRVBase::visitSubgroupAVCBuiltinCall(CallInst *CI,
15951595
// conterpart from 'MCE' with conversion for an argument and result (if needed).
15961596
void OCLToSPIRVBase::visitSubgroupAVCWrapperBuiltinCall(
15971597
CallInst *CI, Op WrappedOC, StringRef DemangledName) {
1598-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
15991598
std::string Prefix = kOCLSubgroupsAVCIntel::Prefix;
16001599

16011600
// Find 'to_mce' conversion function.
@@ -1612,9 +1611,6 @@ void OCLToSPIRVBase::visitSubgroupAVCWrapperBuiltinCall(
16121611
OCLSPIRVSubgroupAVCIntelBuiltinMap::find(ToMCEFName, &ToMCEOC);
16131612
assert(ToMCEOC != OpNop && "Invalid Subgroup AVC Intel built-in call");
16141613

1615-
SmallVector<Type *, 2> ParamTys;
1616-
getParameterTypes(CI, ParamTys);
1617-
16181614
if (std::strcmp(TyKind, "payload") == 0) {
16191615
// Wrapper built-ins which take the 'payload_t' argument return it as
16201616
// the result: two conversion calls required.
@@ -1624,38 +1620,31 @@ void OCLToSPIRVBase::visitSubgroupAVCWrapperBuiltinCall(
16241620
OCLSPIRVSubgroupAVCIntelBuiltinMap::find(FromMCEFName, &FromMCEOC);
16251621
assert(FromMCEOC != OpNop && "Invalid Subgroup AVC Intel built-in call");
16261622

1627-
mutateCallInstSPIRV(
1628-
M, CI,
1629-
[=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1630-
Ret = MCETy;
1631-
// Create conversion function call for the last operand
1632-
Args[Args.size() - 1] = addCallInstSPIRV(
1633-
M, getSPIRVFuncName(ToMCEOC), MCETy, Args[Args.size() - 1],
1634-
nullptr, {ParamTys[Args.size() - 1]}, CI, "");
1635-
1636-
return getSPIRVFuncName(WrappedOC);
1637-
},
1638-
[=](CallInst *NewCI) -> Instruction * {
1623+
mutateCallInst(CI, WrappedOC)
1624+
.mapArg(CI->arg_size() - 1,
1625+
[&](Value *Arg, Type *ParamTy) {
1626+
// Create conversion function call for the last operand
1627+
return std::pair<Value *, Type *>(
1628+
addCallInstSPIRV(M, getSPIRVFuncName(ToMCEOC), MCETy, Arg,
1629+
nullptr, {ParamTy}, CI, ""),
1630+
MCESTy);
1631+
})
1632+
.changeReturnType(MCETy, [=](IRBuilder<> &, CallInst *NewCI) {
16391633
// Create conversion function call for the return result
16401634
return addCallInstSPIRV(M, getSPIRVFuncName(FromMCEOC), CI->getType(),
16411635
NewCI, nullptr, {MCESTy}, CI, "");
1642-
},
1643-
&Attrs);
1636+
});
16441637
} else {
16451638
// Wrapper built-ins which take the 'result_t' argument requires only one
16461639
// conversion for the argument
1647-
mutateCallInstSPIRV(
1648-
M, CI,
1649-
[=](CallInst *, std::vector<Value *> &Args) {
1650-
// Create conversion function call for the last
1651-
// operand
1652-
Args[Args.size() - 1] = addCallInstSPIRV(
1653-
M, getSPIRVFuncName(ToMCEOC), MCETy, Args[Args.size() - 1],
1654-
nullptr, {ParamTys[Args.size() - 1]}, CI, "");
1655-
1656-
return getSPIRVFuncName(WrappedOC);
1657-
},
1658-
&Attrs);
1640+
mutateCallInst(CI, WrappedOC)
1641+
.mapArg(CI->arg_size() - 1, [&](Value *Arg, Type *ParamTy) {
1642+
// Create conversion function call for the last operand
1643+
return std::pair<Value *, Type *>(
1644+
addCallInstSPIRV(M, getSPIRVFuncName(ToMCEOC), MCETy, Arg,
1645+
nullptr, {ParamTy}, CI, ""),
1646+
MCESTy);
1647+
});
16591648
}
16601649
}
16611650

@@ -1677,45 +1666,44 @@ void OCLToSPIRVBase::visitSubgroupAVCBuiltinCallWithSampler(
16771666
if (OC == OpNop)
16781667
return; // this is not a VME built-in
16791668

1680-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1681-
mutateCallInstSPIRV(
1682-
M, CI,
1683-
[=](CallInst *, std::vector<Value *> &Args) {
1684-
SmallVector<Type *, 4> ParamTys;
1685-
getParameterTypes(CI, ParamTys);
1686-
auto *TyIt =
1687-
std::find_if(ParamTys.begin(), ParamTys.end(), isSamplerStructTy);
1688-
assert(TyIt != ParamTys.end() &&
1689-
"Invalid Subgroup AVC Intel built-in call");
1690-
auto SamplerIt = Args.begin() + (TyIt - ParamTys.begin());
1691-
auto *SamplerVal = *SamplerIt;
1692-
auto *SamplerTy = *TyIt;
1693-
Args.erase(SamplerIt);
1694-
ParamTys.erase(TyIt);
1695-
1696-
for (unsigned I = 0, E = Args.size(); I < E; ++I) {
1697-
if (!isOCLImageStructType(ParamTys[I]))
1698-
continue;
1699-
1700-
auto *ImageTy =
1701-
OCLTypeToSPIRVPtr
1702-
->getAdaptedArgumentType(CI->getCalledFunction(), I)
1703-
.second;
1704-
if (!ImageTy)
1705-
ImageTy = ParamTys[I];
1706-
ImageTy = adaptSPIRVImageType(M, ImageTy);
1707-
auto *SampledImgTy = getSPIRVTypeByChangeBaseTypeName(
1708-
M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::VmeImageINTEL);
1709-
1710-
Value *SampledImgArgs[] = {Args[I], SamplerVal};
1711-
Args[I] = addCallInstSPIRV(M, getSPIRVFuncName(OpVmeImageINTEL),
1712-
SampledImgTy, SampledImgArgs, nullptr,
1713-
{ParamTys[I], SamplerTy}, CI,
1714-
kSPIRVName::TempSampledImage);
1715-
}
1716-
return getSPIRVFuncName(OC);
1717-
},
1718-
&Attrs);
1669+
SmallVector<Type *, 4> ParamTys;
1670+
getParameterTypes(CI, ParamTys);
1671+
auto *TyIt =
1672+
std::find_if(ParamTys.begin(), ParamTys.end(), isSamplerStructTy);
1673+
assert(TyIt != ParamTys.end() && "Invalid Subgroup AVC Intel built-in call");
1674+
unsigned SamplerIndex = TyIt - ParamTys.begin();
1675+
Value *SamplerVal = CI->getOperand(SamplerIndex);
1676+
Type *SamplerTy = ParamTys[SamplerIndex];
1677+
1678+
SmallVector<Type *, 4> AdaptedTys;
1679+
for (unsigned I = 0; I < CI->arg_size(); I++)
1680+
AdaptedTys.push_back(
1681+
OCLTypeToSPIRVPtr->getAdaptedArgumentType(CI->getCalledFunction(), I)
1682+
.second);
1683+
auto *AdaptedIter = AdaptedTys.begin();
1684+
1685+
mutateCallInst(CI, OC)
1686+
.mapArgs([&](Value *Arg, Type *PointerTy) {
1687+
if (!isOCLImageStructType(PointerTy))
1688+
return std::make_pair(Arg, PointerTy);
1689+
1690+
auto *ImageTy = *AdaptedIter++;
1691+
if (!ImageTy)
1692+
ImageTy = PointerTy;
1693+
ImageTy = adaptSPIRVImageType(M, ImageTy);
1694+
auto *SampledImgStructTy = getSPIRVStructTypeByChangeBaseTypeName(
1695+
M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::VmeImageINTEL);
1696+
auto *SampledImgTy =
1697+
PointerType::get(SampledImgStructTy, SPIRAS_Global);
1698+
1699+
Value *SampledImgArgs[] = {Arg, SamplerVal};
1700+
return std::pair<Value *, Type *>(
1701+
addCallInstSPIRV(M, getSPIRVFuncName(OpVmeImageINTEL), SampledImgTy,
1702+
SampledImgArgs, nullptr, {PointerTy, SamplerTy},
1703+
CI, kSPIRVName::TempSampledImage),
1704+
SampledImgStructTy);
1705+
})
1706+
.removeArg(SamplerIndex);
17191707
}
17201708

17211709
void OCLToSPIRVBase::visitCallSplitBarrierINTEL(CallInst *CI,

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 63 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -848,74 +848,69 @@ void SPIRVToOCLBase::visitCallSPIRVSubgroupINTELBuiltIn(CallInst *CI, Op OC) {
848848

849849
void SPIRVToOCLBase::visitCallSPIRVAvcINTELEvaluateBuiltIn(CallInst *CI,
850850
Op OC) {
851-
assert(CI->getCalledFunction() && "Unexpected indirect call");
852-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
853-
mutateCallInstOCL(
854-
M, CI,
855-
[=](CallInst *, std::vector<Value *> &Args) {
856-
// There are three types of AVC Intel Evaluate opcodes:
857-
// 1. With multi reference images - does not use OpVmeImageINTEL opcode
858-
// for reference images
859-
// 2. With dual reference images - uses two OpVmeImageINTEL opcodes for
860-
// reference image
861-
// 3. With single reference image - uses one OpVmeImageINTEL opcode for
862-
// reference image
863-
StringRef FnName = CI->getCalledFunction()->getName();
864-
int NumImages = 0;
865-
if (FnName.contains("SingleReference"))
866-
NumImages = 2;
867-
else if (FnName.contains("DualReference"))
868-
NumImages = 3;
869-
else if (FnName.contains("MultiReference"))
870-
NumImages = 1;
871-
else if (FnName.contains("EvaluateIpe"))
872-
NumImages = 1;
873-
874-
auto EraseVmeImageCall = [](CallInst *CI) {
875-
if (CI->hasOneUse()) {
876-
CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
877-
CI->dropAllReferences();
878-
CI->eraseFromParent();
879-
}
880-
};
881-
if (NumImages) {
882-
CallInst *SrcImage = cast<CallInst>(Args[0]);
883-
if (NumImages == 1) {
884-
// Multi reference opcode - remove src image OpVmeImageINTEL opcode
885-
// and replace it with corresponding OpImage and OpSampler arguments
886-
size_t SamplerPos = Args.size() - 1;
887-
Args.erase(Args.begin(), Args.begin() + 1);
888-
Args.insert(Args.begin(), SrcImage->getOperand(0));
889-
Args.insert(Args.begin() + SamplerPos, SrcImage->getOperand(1));
890-
EraseVmeImageCall(SrcImage);
891-
} else {
892-
CallInst *FwdRefImage = cast<CallInst>(Args[1]);
893-
CallInst *BwdRefImage =
894-
NumImages == 3 ? cast<CallInst>(Args[2]) : nullptr;
895-
// Single reference opcode - remove src and ref image
896-
// OpVmeImageINTEL opcodes and replace them with src and ref OpImage
897-
// opcodes and OpSampler
898-
Args.erase(Args.begin(), Args.begin() + NumImages);
899-
// insert source OpImage and OpSampler
900-
auto SrcOps = SrcImage->args();
901-
Args.insert(Args.begin(), SrcOps.begin(), SrcOps.end());
902-
// insert reference OpImage
903-
Args.insert(Args.begin() + 1, FwdRefImage->getOperand(0));
904-
EraseVmeImageCall(SrcImage);
905-
EraseVmeImageCall(FwdRefImage);
906-
if (BwdRefImage) {
907-
// Dual reference opcode - insert second reference OpImage
908-
// argument
909-
Args.insert(Args.begin() + 2, BwdRefImage->getOperand(0));
910-
EraseVmeImageCall(BwdRefImage);
911-
}
912-
}
913-
} else
914-
llvm_unreachable("invalid avc instruction");
915-
916-
return OCLSPIRVSubgroupAVCIntelBuiltinMap::rmap(OC);
917-
},
918-
&Attrs);
851+
// There are three types of AVC Intel Evaluate opcodes:
852+
// 1. With multi reference images - does not use OpVmeImageINTEL opcode
853+
// for reference images
854+
// 2. With dual reference images - uses two OpVmeImageINTEL opcodes for
855+
// reference image
856+
// 3. With single reference image - uses one OpVmeImageINTEL opcode for
857+
// reference image
858+
StringRef FnName = CI->getCalledFunction()->getName();
859+
int NumImages = 0;
860+
if (FnName.contains("SingleReference"))
861+
NumImages = 2;
862+
else if (FnName.contains("DualReference"))
863+
NumImages = 3;
864+
else if (FnName.contains("MultiReference"))
865+
NumImages = 1;
866+
else if (FnName.contains("EvaluateIpe"))
867+
NumImages = 1;
868+
869+
auto EraseVmeImageCall = [](CallInst *CI) {
870+
if (CI->hasOneUse()) {
871+
CI->replaceAllUsesWith(UndefValue::get(CI->getType()));
872+
CI->dropAllReferences();
873+
CI->eraseFromParent();
874+
}
875+
};
876+
877+
auto Mutator =
878+
mutateCallInst(CI, OCLSPIRVSubgroupAVCIntelBuiltinMap::rmap(OC));
879+
if (NumImages) {
880+
CallInst *SrcImage = cast<CallInst>(Mutator.getArg(0));
881+
SmallVector<Type *, 2> SrcImageTys;
882+
getParameterTypes(SrcImage, SrcImageTys);
883+
if (NumImages == 1) {
884+
// Multi reference opcode - remove src image OpVmeImageINTEL opcode
885+
// and replace it with corresponding OpImage and OpSampler arguments
886+
size_t SamplerPos = Mutator.arg_size() - 1;
887+
Mutator.replaceArg(0, {SrcImage->getOperand(0), SrcImageTys[0]});
888+
Mutator.insertArg(SamplerPos, {SrcImage->getOperand(1), SrcImageTys[1]});
889+
} else {
890+
CallInst *FwdRefImage = cast<CallInst>(Mutator.getArg(1));
891+
CallInst *BwdRefImage =
892+
NumImages == 3 ? cast<CallInst>(Mutator.getArg(2)) : nullptr;
893+
// Single reference opcode - remove src and ref image
894+
// OpVmeImageINTEL opcodes and replace them with src and ref OpImage
895+
// opcodes and OpSampler
896+
Mutator.removeArgs(0, NumImages);
897+
// insert source OpImage and OpSampler
898+
Mutator.insertArg(0, {SrcImage->getOperand(0), SrcImageTys[0]});
899+
Mutator.insertArg(1, {SrcImage->getOperand(1), SrcImageTys[1]});
900+
// insert reference OpImage
901+
getParameterTypes(FwdRefImage, SrcImageTys);
902+
Mutator.insertArg(1, {FwdRefImage->getOperand(0), SrcImageTys[0]});
903+
EraseVmeImageCall(SrcImage);
904+
EraseVmeImageCall(FwdRefImage);
905+
if (BwdRefImage) {
906+
// Dual reference opcode - insert second reference OpImage argument
907+
getParameterTypes(BwdRefImage, SrcImageTys);
908+
Mutator.insertArg(2, {BwdRefImage->getOperand(0), SrcImageTys[0]});
909+
EraseVmeImageCall(BwdRefImage);
910+
}
911+
}
912+
} else
913+
llvm_unreachable("invalid avc instruction");
919914
}
920915

921916
void SPIRVToOCLBase::visitCallSPIRVGenericPtrMemSemantics(CallInst *CI) {

0 commit comments

Comments
 (0)