Skip to content

Commit 38a25d1

Browse files
jcranmer-inteldbudanov-cmplr
authored andcommitted
Migrate miscellaneous code in SPIRVToOCL to the mutator interface.
The test change is because use of an IRBuilder causes extra constant folding to happen that previously didn't happen. Original commit: KhronosGroup/SPIRV-LLVM-Translator@44c70a8
1 parent 1e7915e commit 38a25d1

File tree

1 file changed

+99
-126
lines changed

1 file changed

+99
-126
lines changed

llvm-spirv/lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 99 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -535,48 +535,39 @@ static bool needsInt32RetTy(Op OC) {
535535

536536
void SPIRVToOCLBase::visitCallSPIRVGroupBuiltin(CallInst *CI, Op OC) {
537537
auto FuncName = groupOCToOCLBuiltinName(CI, OC);
538-
auto ModifyArguments = [=](CallInst *, std::vector<Value *> &Args,
539-
llvm::Type *&RetTy) {
540-
Type *Int32Ty = Type::getInt32Ty(*Ctx);
541-
bool HasArg0ExtendedToi32 =
542-
OC == OpGroupAny || OC == OpGroupAll || OC == OpGroupNonUniformAny ||
543-
OC == OpGroupNonUniformAll || OC == OpGroupNonUniformBallot ||
544-
isGroupLogicalOpCode(OC);
545-
/// Remove Group Operation argument,
546-
/// as in OpenCL representation this is included in the function name
547-
Args.erase(Args.begin(), Args.begin() + (hasGroupOperation(OC) ? 2 : 1));
548-
549-
// Handle function arguments
550-
if (OC == OpGroupBroadcast)
551-
expandVector(CI, Args, 1);
552-
else if (HasArg0ExtendedToi32)
553-
Args[0] = CastInst::CreateZExtOrBitCast(Args[0], Int32Ty, "", CI);
554-
555-
// Handle function return type
556-
if (needsInt32RetTy(OC))
557-
RetTy = Int32Ty;
558-
559-
return FuncName;
560-
};
561-
auto ModifyRetTy = [=](CallInst *CI) -> Instruction * {
562-
if (needsInt32RetTy(OC)) {
538+
auto Mutator = mutateCallInst(CI, FuncName);
539+
/// Remove Group Operation argument,
540+
/// as in OpenCL representation this is included in the function name
541+
Mutator.removeArgs(0, (hasGroupOperation(OC) ? 2 : 1));
542+
543+
Type *Int32Ty = Type::getInt32Ty(*Ctx);
544+
bool HasArg0ExtendedToi32 =
545+
OC == OpGroupAny || OC == OpGroupAll || OC == OpGroupNonUniformAny ||
546+
OC == OpGroupNonUniformAll || OC == OpGroupNonUniformBallot ||
547+
isGroupLogicalOpCode(OC);
548+
549+
// Handle function arguments
550+
if (OC == OpGroupBroadcast) {
551+
Value *VecArg = Mutator.getArg(1);
552+
if (auto *VT = dyn_cast<FixedVectorType>(VecArg->getType())) {
553+
unsigned NumElements = VT->getNumElements();
554+
for (unsigned I = 0; I < NumElements; I++)
555+
Mutator.insertArg(1 + I, Mutator.Builder.CreateExtractElement(
556+
VecArg, Mutator.Builder.getInt32(I)));
557+
Mutator.removeArg(1 + NumElements);
558+
}
559+
} else if (HasArg0ExtendedToi32)
560+
Mutator.mapArg(0, [](IRBuilder<> &Builder, Value *V) {
561+
return Builder.CreateZExt(V, Builder.getInt32Ty());
562+
});
563+
564+
// Handle function return type
565+
if (needsInt32RetTy(OC))
566+
Mutator.changeReturnType(Int32Ty, [](IRBuilder<> &Builder, CallInst *CI) {
563567
// The OpenCL builtin returns a non-zero integer value. Convert to a
564568
// boolean value.
565-
Constant *Zero = ConstantInt::get(CI->getType(), 0);
566-
return new ICmpInst(CI->getNextNode(), CmpInst::ICMP_NE, CI, Zero);
567-
} else
568-
return CI;
569-
};
570-
571-
assert(CI->getCalledFunction() && "Unexpected indirect call");
572-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
573-
SmallVector<AttributeSet, 2> ArgAttrs;
574-
for (int I = (hasGroupOperation(OC) ? 2 : 1);
575-
I < (int)Attrs.getNumAttrSets() - 2; I++)
576-
ArgAttrs.push_back(Attrs.getParamAttrs(I));
577-
Attrs = AttributeList::get(*Ctx, Attrs.getFnAttrs(), Attrs.getRetAttrs(),
578-
ArgAttrs);
579-
mutateCallInstOCL(M, CI, ModifyArguments, ModifyRetTy, &Attrs);
569+
return Builder.CreateICmpNE(CI, Builder.getInt32(0));
570+
});
580571
}
581572

582573
void SPIRVToOCLBase::visitCallSPIRVPipeBuiltin(CallInst *CI, Op OC) {
@@ -586,64 +577,52 @@ void SPIRVToOCLBase::visitCallSPIRVPipeBuiltin(CallInst *CI, Op OC) {
586577
DemangledName = getGroupBuiltinPrefix(CI) + DemangledName;
587578

588579
assert(CI->getCalledFunction() && "Unexpected indirect call");
589-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
590-
mutateCallInstOCL(
591-
M, CI,
592-
[=](CallInst *, std::vector<Value *> &Args) {
593-
if (HasScope)
594-
Args.erase(Args.begin(), Args.begin() + 1);
595-
596-
if (!(OC == OpReadPipe || OC == OpWritePipe ||
597-
OC == OpReservedReadPipe || OC == OpReservedWritePipe ||
598-
OC == OpReadPipeBlockingINTEL || OC == OpWritePipeBlockingINTEL))
599-
return DemangledName;
600-
601-
auto &P = Args[Args.size() - 3];
602-
auto T = P->getType();
603-
assert(isa<PointerType>(T));
604-
auto *NewTy = PointerType::getInt8PtrTy(*Ctx, SPIRAS_Generic);
605-
if (T != NewTy) {
606-
P = CastInst::CreatePointerBitCastOrAddrSpaceCast(P, NewTy, "", CI);
607-
}
608-
return DemangledName;
609-
},
610-
&Attrs);
580+
auto Mutator = mutateCallInst(CI, DemangledName);
581+
if (HasScope)
582+
Mutator.removeArg(0);
583+
if (OC == OpReadPipe || OC == OpWritePipe || OC == OpReservedReadPipe ||
584+
OC == OpReservedWritePipe || OC == OpReadPipeBlockingINTEL ||
585+
OC == OpWritePipeBlockingINTEL) {
586+
Mutator.mapArg(Mutator.arg_size() - 3, [](IRBuilder<> &Builder, Value *P) {
587+
Type *T = P->getType();
588+
assert(isa<PointerType>(T));
589+
auto *NewTy = Builder.getInt8PtrTy(SPIRAS_Generic);
590+
if (T != NewTy) {
591+
P = Builder.CreatePointerBitCastOrAddrSpaceCast(P, NewTy);
592+
}
593+
return std::pair<Value *, Type *>(P, Builder.getInt8Ty());
594+
});
595+
}
611596
}
612597

613598
void SPIRVToOCLBase::visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC) {
614-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
615-
mutateCallInstOCL(
616-
M, CI,
617-
[=](CallInst *, std::vector<Value *> &Args) {
618-
// Moving the first argument to the end.
619-
std::rotate(Args.rbegin(), Args.rend() - 1, Args.rend());
620-
Type *RetType = CI->getType();
621-
if (OC == OpSubgroupImageMediaBlockWriteINTEL) {
622-
assert(Args.size() >= 4 && "Wrong media block write signature");
623-
RetType = Args.at(3)->getType(); // texel type
624-
}
625-
unsigned int BitWidth = RetType->getScalarSizeInBits();
626-
std::string FuncPostfix;
627-
if (BitWidth == 8)
628-
FuncPostfix = "_uc";
629-
else if (BitWidth == 16)
630-
FuncPostfix = "_us";
631-
else if (BitWidth == 32)
632-
FuncPostfix = "_ui";
633-
else
634-
assert(0 && "Unsupported texel type!");
599+
Type *RetType = CI->getType();
600+
if (OC == OpSubgroupImageMediaBlockWriteINTEL) {
601+
assert(CI->arg_size() >= 5 && "Wrong media block write signature");
602+
RetType = CI->getArgOperand(4)->getType(); // texel type
603+
}
604+
unsigned int BitWidth = RetType->getScalarSizeInBits();
605+
std::string FuncPostfix;
606+
if (BitWidth == 8)
607+
FuncPostfix = "_uc";
608+
else if (BitWidth == 16)
609+
FuncPostfix = "_us";
610+
else if (BitWidth == 32)
611+
FuncPostfix = "_ui";
612+
else
613+
assert(0 && "Unsupported texel type!");
635614

636-
if (auto *VecTy = dyn_cast<FixedVectorType>(RetType)) {
637-
unsigned int NumEl = VecTy->getNumElements();
638-
assert((NumEl == 2 || NumEl == 4 || NumEl == 8 || NumEl == 16) &&
639-
"Wrong function type!");
640-
FuncPostfix += std::to_string(NumEl);
641-
}
615+
if (auto *VecTy = dyn_cast<FixedVectorType>(RetType)) {
616+
unsigned int NumEl = VecTy->getNumElements();
617+
assert((NumEl == 2 || NumEl == 4 || NumEl == 8 || NumEl == 16) &&
618+
"Wrong function type!");
619+
FuncPostfix += std::to_string(NumEl);
620+
}
642621

643-
return OCLSPIRVBuiltinMap::rmap(OC) + FuncPostfix;
644-
},
645-
&Attrs);
622+
mutateCallInst(CI, OCLSPIRVBuiltinMap::rmap(OC) + FuncPostfix)
623+
.moveArg(0, CI->arg_size() - 1);
646624
}
625+
647626
void SPIRVToOCLBase::visitCallBuildNDRangeBuiltIn(CallInst *CI, Op OC,
648627
StringRef DemangledName) {
649628
assert(CI->getCalledFunction() && "Unexpected indirect call");
@@ -837,40 +816,34 @@ void SPIRVToOCLBase::visitCallSPIRVImageQueryBuiltIn(CallInst *CI, Op OC) {
837816
}
838817

839818
void SPIRVToOCLBase::visitCallSPIRVSubgroupINTELBuiltIn(CallInst *CI, Op OC) {
840-
assert(CI->getCalledFunction() && "Unexpected indirect call");
841-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
842-
mutateCallInstOCL(
843-
M, CI,
844-
[=](CallInst *, std::vector<Value *> &Args) {
845-
std::stringstream Name;
846-
Type *DataTy = nullptr;
847-
switch (OC) {
848-
case OpSubgroupBlockReadINTEL:
849-
case OpSubgroupImageBlockReadINTEL:
850-
Name << "intel_sub_group_block_read";
851-
DataTy = CI->getType();
852-
break;
853-
case OpSubgroupBlockWriteINTEL:
854-
Name << "intel_sub_group_block_write";
855-
DataTy = CI->getOperand(1)->getType();
856-
break;
857-
case OpSubgroupImageBlockWriteINTEL:
858-
Name << "intel_sub_group_block_write";
859-
DataTy = CI->getOperand(2)->getType();
860-
break;
861-
default:
862-
return OCLSPIRVBuiltinMap::rmap(OC);
863-
}
864-
assert(DataTy && "Intel subgroup block builtins should have data type");
865-
unsigned VectorNumElements = 1;
866-
if (FixedVectorType *VT = dyn_cast<FixedVectorType>(DataTy))
867-
VectorNumElements = VT->getNumElements();
868-
unsigned ElementBitSize = DataTy->getScalarSizeInBits();
869-
Name << getIntelSubgroupBlockDataPostfix(ElementBitSize,
870-
VectorNumElements);
871-
return Name.str();
872-
},
873-
&Attrs);
819+
std::stringstream Name;
820+
Type *DataTy = nullptr;
821+
switch (OC) {
822+
case OpSubgroupBlockReadINTEL:
823+
case OpSubgroupImageBlockReadINTEL:
824+
Name << "intel_sub_group_block_read";
825+
DataTy = CI->getType();
826+
break;
827+
case OpSubgroupBlockWriteINTEL:
828+
Name << "intel_sub_group_block_write";
829+
DataTy = CI->getOperand(1)->getType();
830+
break;
831+
case OpSubgroupImageBlockWriteINTEL:
832+
Name << "intel_sub_group_block_write";
833+
DataTy = CI->getOperand(2)->getType();
834+
break;
835+
default:
836+
Name << OCLSPIRVBuiltinMap::rmap(OC);
837+
break;
838+
}
839+
if (DataTy) {
840+
unsigned VectorNumElements = 1;
841+
if (FixedVectorType *VT = dyn_cast<FixedVectorType>(DataTy))
842+
VectorNumElements = VT->getNumElements();
843+
unsigned ElementBitSize = DataTy->getScalarSizeInBits();
844+
Name << getIntelSubgroupBlockDataPostfix(ElementBitSize, VectorNumElements);
845+
}
846+
mutateCallInst(CI, Name.str());
874847
}
875848

876849
void SPIRVToOCLBase::visitCallSPIRVAvcINTELEvaluateBuiltIn(CallInst *CI,

0 commit comments

Comments
 (0)