Skip to content

Commit 805bc3d

Browse files
jcranmer-inteldbudanov-cmplr
authored andcommitted
Migrate miscellaneous code in OCLToSPIRV to the mutator interface. (#1621)
Original commit: KhronosGroup/SPIRV-LLVM-Translator@a33e094
1 parent 38a25d1 commit 805bc3d

File tree

1 file changed

+61
-74
lines changed

1 file changed

+61
-74
lines changed

llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 61 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -941,69 +941,60 @@ void OCLToSPIRVBase::visitCallReadImageWithSampler(CallInst *CI,
941941
assert(MangledName.find(kMangledName::Sampler) != StringRef::npos);
942942
assert(CI->getCalledFunction() && "Unexpected indirect call");
943943
Function *Func = CI->getCalledFunction();
944-
AttributeList Attrs = Func->getAttributes();
945944
bool IsRetScalar = !CI->getType()->isVectorTy();
946945
SmallVector<Type *, 3> ArgStructTys;
947946
getParameterTypes(CI, ArgStructTys);
948-
mutateCallInstSPIRV(
949-
M, CI,
950-
[=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
951-
auto *ImageTy =
952-
OCLTypeToSPIRVPtr->getAdaptedArgumentType(Func, 0).second;
953-
if (!ImageTy)
954-
ImageTy = ArgStructTys[0];
955-
ImageTy = adaptSPIRVImageType(M, ImageTy);
956-
auto SampledImgTy = getSPIRVTypeByChangeBaseTypeName(
957-
M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::SampledImg);
958-
Value *SampledImgArgs[] = {Args[0], Args[1]};
959-
auto SampledImg = addCallInstSPIRV(
960-
M, getSPIRVFuncName(OpSampledImage), SampledImgTy, SampledImgArgs,
961-
nullptr, {ArgStructTys[0], ArgStructTys[1]}, CI,
962-
kSPIRVName::TempSampledImage);
963-
964-
Args[0] = SampledImg;
965-
Args.erase(Args.begin() + 1, Args.begin() + 2);
966-
967-
unsigned ImgOpMask = getImageSignZeroExt(DemangledName);
968-
unsigned ImgOpMaskInsIndex = Args.size();
969-
switch (Args.size()) {
970-
case 2: // no lod
971-
ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
972-
ImgOpMaskInsIndex = Args.size();
973-
Args.push_back(getFloat32(M, 0.f));
974-
break;
975-
case 3: // explicit lod
976-
ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
977-
ImgOpMaskInsIndex = 2;
978-
break;
979-
case 4: // gradient
980-
ImgOpMask |= ImageOperandsMask::ImageOperandsGradMask;
981-
ImgOpMaskInsIndex = 2;
982-
break;
983-
default:
984-
assert(0 && "read_image* with unhandled number of args!");
985-
}
986-
Args.insert(Args.begin() + ImgOpMaskInsIndex, getInt32(M, ImgOpMask));
987-
988-
// SPIR-V instruction always returns 4-element vector
989-
if (IsRetScalar)
990-
Ret = FixedVectorType::get(Ret, 4);
991-
return getSPIRVFuncName(OpImageSampleExplicitLod,
992-
std::string(kSPIRVPostfix::ExtDivider) +
993-
getPostfixForReturnType(Ret));
994-
},
995-
[&](CallInst *CI) -> Instruction * {
996-
if (IsRetScalar)
997-
return ExtractElementInst::Create(CI, getSizet(M, 0), "",
998-
CI->getNextNode());
999-
return CI;
1000-
},
1001-
&Attrs);
947+
Type *Ret = CI->getType();
948+
auto *ImageTy = OCLTypeToSPIRVPtr->getAdaptedArgumentType(Func, 0).second;
949+
if (!ImageTy)
950+
ImageTy = ArgStructTys[0];
951+
ImageTy = adaptSPIRVImageType(M, ImageTy);
952+
auto *SampledImgStructTy = getSPIRVStructTypeByChangeBaseTypeName(
953+
M, ImageTy, kSPIRVTypeName::Image, kSPIRVTypeName::SampledImg);
954+
auto *SampledImgTy = PointerType::get(SampledImgStructTy, SPIRAS_Global);
955+
Value *SampledImgArgs[] = {CI->getArgOperand(0), CI->getArgOperand(1)};
956+
auto *SampledImg = addCallInstSPIRV(M, getSPIRVFuncName(OpSampledImage),
957+
SampledImgTy, SampledImgArgs, nullptr,
958+
{ArgStructTys[0], ArgStructTys[1]}, CI,
959+
kSPIRVName::TempSampledImage);
960+
961+
auto Mutator = mutateCallInst(
962+
CI, getSPIRVFuncName(OpImageSampleExplicitLod,
963+
std::string(kSPIRVPostfix::ExtDivider) +
964+
getPostfixForReturnType(Ret)));
965+
Mutator.replaceArg(0, {SampledImg, SampledImgStructTy}).removeArg(1);
966+
unsigned ImgOpMask = getImageSignZeroExt(DemangledName);
967+
unsigned ImgOpMaskInsIndex = Mutator.arg_size();
968+
switch (Mutator.arg_size()) {
969+
case 2: // no lod
970+
ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
971+
ImgOpMaskInsIndex = Mutator.arg_size();
972+
Mutator.appendArg(getFloat32(M, 0.f));
973+
break;
974+
case 3: // explicit lod
975+
ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
976+
ImgOpMaskInsIndex = 2;
977+
break;
978+
case 4: // gradient
979+
ImgOpMask |= ImageOperandsMask::ImageOperandsGradMask;
980+
ImgOpMaskInsIndex = 2;
981+
break;
982+
default:
983+
assert(0 && "read_image* with unhandled number of args!");
984+
}
985+
Mutator.insertArg(ImgOpMaskInsIndex, getInt32(M, ImgOpMask));
986+
987+
// SPIR-V instruction always returns 4-element vector
988+
if (IsRetScalar)
989+
Mutator.changeReturnType(FixedVectorType::get(Ret, 4),
990+
[=](IRBuilder<> &Builder, CallInst *NewCI) {
991+
return Builder.CreateExtractElement(
992+
NewCI, getSizet(M, 0));
993+
});
1002994
}
1003995

1004996
void OCLToSPIRVBase::visitCallGetImageSize(CallInst *CI,
1005997
StringRef DemangledName) {
1006-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1007998
StringRef TyName;
1008999
SmallVector<StringRef, 4> SubStrs;
10091000
SmallVector<Type *, 4> ParamTys;
@@ -1015,22 +1006,19 @@ void OCLToSPIRVBase::visitCallGetImageSize(CallInst *CI,
10151006
auto Desc = map<SPIRVTypeImageDescriptor>(ImageTyName);
10161007
unsigned Dim = getImageDimension(Desc.Dim) + Desc.Arrayed;
10171008
assert(Dim > 0 && "Invalid image dimension.");
1018-
mutateCallInstSPIRV(
1019-
M, CI,
1020-
[&](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1021-
assert(Args.size() == 1);
1022-
Ret = CI->getType()->isIntegerTy(64) ? Type::getInt64Ty(*Ctx)
1023-
: Type::getInt32Ty(*Ctx);
1024-
if (Dim > 1)
1025-
Ret = FixedVectorType::get(Ret, Dim);
1026-
if (Desc.Dim == DimBuffer)
1027-
return getSPIRVFuncName(OpImageQuerySize, CI->getType());
1028-
else {
1029-
Args.push_back(getInt32(M, 0));
1030-
return getSPIRVFuncName(OpImageQuerySizeLod, CI->getType());
1031-
}
1032-
},
1033-
[&](CallInst *NCI) -> Instruction * {
1009+
assert(CI->arg_size() == 1);
1010+
Type *NewRet = CI->getType()->isIntegerTy(64) ? Type::getInt64Ty(*Ctx)
1011+
: Type::getInt32Ty(*Ctx);
1012+
if (Dim > 1)
1013+
NewRet = FixedVectorType::get(NewRet, Dim);
1014+
auto Mutator = mutateCallInst(CI, getSPIRVFuncName(Desc.Dim == DimBuffer
1015+
? OpImageQuerySize
1016+
: OpImageQuerySizeLod,
1017+
CI->getType()));
1018+
if (Desc.Dim != DimBuffer)
1019+
Mutator.appendArg(getInt32(M, 0));
1020+
Mutator.changeReturnType(
1021+
NewRet, [&](IRBuilder<> &, CallInst *NCI) -> Value * {
10341022
if (Dim == 1)
10351023
return NCI;
10361024
if (DemangledName == kOCLBuiltinName::GetImageDim) {
@@ -1059,8 +1047,7 @@ void OCLToSPIRVBase::visitCallGetImageSize(CallInst *CI,
10591047
.Case(kOCLBuiltinName::GetImageArraySize, Dim - 1);
10601048
return ExtractElementInst::Create(NCI, getUInt32(M, I), "",
10611049
NCI->getNextNode());
1062-
},
1063-
&Attrs);
1050+
});
10641051
}
10651052

10661053
/// Remove trivial conversion functions

0 commit comments

Comments
 (0)