@@ -941,69 +941,60 @@ void OCLToSPIRVBase::visitCallReadImageWithSampler(CallInst *CI,
941
941
assert (MangledName.find (kMangledName ::Sampler) != StringRef::npos);
942
942
assert (CI->getCalledFunction () && " Unexpected indirect call" );
943
943
Function *Func = CI->getCalledFunction ();
944
- AttributeList Attrs = Func->getAttributes ();
945
944
bool IsRetScalar = !CI->getType ()->isVectorTy ();
946
945
SmallVector<Type *, 3 > ArgStructTys;
947
946
getParameterTypes (CI, ArgStructTys);
948
- mutateCallInstSPIRV (
949
- M, CI,
950
- [=](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
951
- auto *ImageTy =
952
- OCLTypeToSPIRVPtr->getAdaptedArgumentType (Func, 0 ).second ;
953
- if (!ImageTy)
954
- ImageTy = ArgStructTys[0 ];
955
- ImageTy = adaptSPIRVImageType (M, ImageTy);
956
- auto SampledImgTy = getSPIRVTypeByChangeBaseTypeName (
957
- M, ImageTy, kSPIRVTypeName ::Image, kSPIRVTypeName ::SampledImg);
958
- Value *SampledImgArgs[] = {Args[0 ], Args[1 ]};
959
- auto SampledImg = addCallInstSPIRV (
960
- M, getSPIRVFuncName (OpSampledImage), SampledImgTy, SampledImgArgs,
961
- nullptr , {ArgStructTys[0 ], ArgStructTys[1 ]}, CI,
962
- kSPIRVName ::TempSampledImage);
963
-
964
- Args[0 ] = SampledImg;
965
- Args.erase (Args.begin () + 1 , Args.begin () + 2 );
966
-
967
- unsigned ImgOpMask = getImageSignZeroExt (DemangledName);
968
- unsigned ImgOpMaskInsIndex = Args.size ();
969
- switch (Args.size ()) {
970
- case 2 : // no lod
971
- ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
972
- ImgOpMaskInsIndex = Args.size ();
973
- Args.push_back (getFloat32 (M, 0 .f ));
974
- break ;
975
- case 3 : // explicit lod
976
- ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
977
- ImgOpMaskInsIndex = 2 ;
978
- break ;
979
- case 4 : // gradient
980
- ImgOpMask |= ImageOperandsMask::ImageOperandsGradMask;
981
- ImgOpMaskInsIndex = 2 ;
982
- break ;
983
- default :
984
- assert (0 && " read_image* with unhandled number of args!" );
985
- }
986
- Args.insert (Args.begin () + ImgOpMaskInsIndex, getInt32 (M, ImgOpMask));
987
-
988
- // SPIR-V instruction always returns 4-element vector
989
- if (IsRetScalar)
990
- Ret = FixedVectorType::get (Ret, 4 );
991
- return getSPIRVFuncName (OpImageSampleExplicitLod,
992
- std::string (kSPIRVPostfix ::ExtDivider) +
993
- getPostfixForReturnType (Ret));
994
- },
995
- [&](CallInst *CI) -> Instruction * {
996
- if (IsRetScalar)
997
- return ExtractElementInst::Create (CI, getSizet (M, 0 ), " " ,
998
- CI->getNextNode ());
999
- return CI;
1000
- },
1001
- &Attrs);
947
+ Type *Ret = CI->getType ();
948
+ auto *ImageTy = OCLTypeToSPIRVPtr->getAdaptedArgumentType (Func, 0 ).second ;
949
+ if (!ImageTy)
950
+ ImageTy = ArgStructTys[0 ];
951
+ ImageTy = adaptSPIRVImageType (M, ImageTy);
952
+ auto *SampledImgStructTy = getSPIRVStructTypeByChangeBaseTypeName (
953
+ M, ImageTy, kSPIRVTypeName ::Image, kSPIRVTypeName ::SampledImg);
954
+ auto *SampledImgTy = PointerType::get (SampledImgStructTy, SPIRAS_Global);
955
+ Value *SampledImgArgs[] = {CI->getArgOperand (0 ), CI->getArgOperand (1 )};
956
+ auto *SampledImg = addCallInstSPIRV (M, getSPIRVFuncName (OpSampledImage),
957
+ SampledImgTy, SampledImgArgs, nullptr ,
958
+ {ArgStructTys[0 ], ArgStructTys[1 ]}, CI,
959
+ kSPIRVName ::TempSampledImage);
960
+
961
+ auto Mutator = mutateCallInst (
962
+ CI, getSPIRVFuncName (OpImageSampleExplicitLod,
963
+ std::string (kSPIRVPostfix ::ExtDivider) +
964
+ getPostfixForReturnType (Ret)));
965
+ Mutator.replaceArg (0 , {SampledImg, SampledImgStructTy}).removeArg (1 );
966
+ unsigned ImgOpMask = getImageSignZeroExt (DemangledName);
967
+ unsigned ImgOpMaskInsIndex = Mutator.arg_size ();
968
+ switch (Mutator.arg_size ()) {
969
+ case 2 : // no lod
970
+ ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
971
+ ImgOpMaskInsIndex = Mutator.arg_size ();
972
+ Mutator.appendArg (getFloat32 (M, 0 .f ));
973
+ break ;
974
+ case 3 : // explicit lod
975
+ ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
976
+ ImgOpMaskInsIndex = 2 ;
977
+ break ;
978
+ case 4 : // gradient
979
+ ImgOpMask |= ImageOperandsMask::ImageOperandsGradMask;
980
+ ImgOpMaskInsIndex = 2 ;
981
+ break ;
982
+ default :
983
+ assert (0 && " read_image* with unhandled number of args!" );
984
+ }
985
+ Mutator.insertArg (ImgOpMaskInsIndex, getInt32 (M, ImgOpMask));
986
+
987
+ // SPIR-V instruction always returns 4-element vector
988
+ if (IsRetScalar)
989
+ Mutator.changeReturnType (FixedVectorType::get (Ret, 4 ),
990
+ [=](IRBuilder<> &Builder, CallInst *NewCI) {
991
+ return Builder.CreateExtractElement (
992
+ NewCI, getSizet (M, 0 ));
993
+ });
1002
994
}
1003
995
1004
996
void OCLToSPIRVBase::visitCallGetImageSize (CallInst *CI,
1005
997
StringRef DemangledName) {
1006
- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1007
998
StringRef TyName;
1008
999
SmallVector<StringRef, 4 > SubStrs;
1009
1000
SmallVector<Type *, 4 > ParamTys;
@@ -1015,22 +1006,19 @@ void OCLToSPIRVBase::visitCallGetImageSize(CallInst *CI,
1015
1006
auto Desc = map<SPIRVTypeImageDescriptor>(ImageTyName);
1016
1007
unsigned Dim = getImageDimension (Desc.Dim ) + Desc.Arrayed ;
1017
1008
assert (Dim > 0 && " Invalid image dimension." );
1018
- mutateCallInstSPIRV (
1019
- M, CI,
1020
- [&](CallInst *, std::vector<Value *> &Args, Type *&Ret) {
1021
- assert (Args.size () == 1 );
1022
- Ret = CI->getType ()->isIntegerTy (64 ) ? Type::getInt64Ty (*Ctx)
1023
- : Type::getInt32Ty (*Ctx);
1024
- if (Dim > 1 )
1025
- Ret = FixedVectorType::get (Ret, Dim);
1026
- if (Desc.Dim == DimBuffer)
1027
- return getSPIRVFuncName (OpImageQuerySize, CI->getType ());
1028
- else {
1029
- Args.push_back (getInt32 (M, 0 ));
1030
- return getSPIRVFuncName (OpImageQuerySizeLod, CI->getType ());
1031
- }
1032
- },
1033
- [&](CallInst *NCI) -> Instruction * {
1009
+ assert (CI->arg_size () == 1 );
1010
+ Type *NewRet = CI->getType ()->isIntegerTy (64 ) ? Type::getInt64Ty (*Ctx)
1011
+ : Type::getInt32Ty (*Ctx);
1012
+ if (Dim > 1 )
1013
+ NewRet = FixedVectorType::get (NewRet, Dim);
1014
+ auto Mutator = mutateCallInst (CI, getSPIRVFuncName (Desc.Dim == DimBuffer
1015
+ ? OpImageQuerySize
1016
+ : OpImageQuerySizeLod,
1017
+ CI->getType ()));
1018
+ if (Desc.Dim != DimBuffer)
1019
+ Mutator.appendArg (getInt32 (M, 0 ));
1020
+ Mutator.changeReturnType (
1021
+ NewRet, [&](IRBuilder<> &, CallInst *NCI) -> Value * {
1034
1022
if (Dim == 1 )
1035
1023
return NCI;
1036
1024
if (DemangledName == kOCLBuiltinName ::GetImageDim) {
@@ -1059,8 +1047,7 @@ void OCLToSPIRVBase::visitCallGetImageSize(CallInst *CI,
1059
1047
.Case (kOCLBuiltinName ::GetImageArraySize, Dim - 1 );
1060
1048
return ExtractElementInst::Create (NCI, getUInt32 (M, I), " " ,
1061
1049
NCI->getNextNode ());
1062
- },
1063
- &Attrs);
1050
+ });
1064
1051
}
1065
1052
1066
1053
// / Remove trivial conversion functions
0 commit comments