Skip to content

Commit 44c70a8

Browse files
jcranmer-intelsvenvh
authored andcommitted
[Mutator] Migrate miscellaneous code in SPIRVToOCL to the mutator interface.
The test change is because use of an IRBuilder causes extra constant folding to happen that previously didn't happen.
1 parent ca73358 commit 44c70a8

File tree

2 files changed

+102
-132
lines changed

2 files changed

+102
-132
lines changed

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 99 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -535,48 +535,39 @@ static bool needsInt32RetTy(Op OC) {
535535

536536
void SPIRVToOCLBase::visitCallSPIRVGroupBuiltin(CallInst *CI, Op OC) {
537537
auto FuncName = groupOCToOCLBuiltinName(CI, OC);
538-
auto ModifyArguments = [=](CallInst *, std::vector<Value *> &Args,
539-
llvm::Type *&RetTy) {
540-
Type *Int32Ty = Type::getInt32Ty(*Ctx);
541-
bool HasArg0ExtendedToi32 =
542-
OC == OpGroupAny || OC == OpGroupAll || OC == OpGroupNonUniformAny ||
543-
OC == OpGroupNonUniformAll || OC == OpGroupNonUniformBallot ||
544-
isGroupLogicalOpCode(OC);
545-
/// Remove Group Operation argument,
546-
/// as in OpenCL representation this is included in the function name
547-
Args.erase(Args.begin(), Args.begin() + (hasGroupOperation(OC) ? 2 : 1));
548-
549-
// Handle function arguments
550-
if (OC == OpGroupBroadcast)
551-
expandVector(CI, Args, 1);
552-
else if (HasArg0ExtendedToi32)
553-
Args[0] = CastInst::CreateZExtOrBitCast(Args[0], Int32Ty, "", CI);
554-
555-
// Handle function return type
556-
if (needsInt32RetTy(OC))
557-
RetTy = Int32Ty;
558-
559-
return FuncName;
560-
};
561-
auto ModifyRetTy = [=](CallInst *CI) -> Instruction * {
562-
if (needsInt32RetTy(OC)) {
538+
auto Mutator = mutateCallInst(CI, FuncName);
539+
/// Remove Group Operation argument,
540+
/// as in OpenCL representation this is included in the function name
541+
Mutator.removeArgs(0, (hasGroupOperation(OC) ? 2 : 1));
542+
543+
Type *Int32Ty = Type::getInt32Ty(*Ctx);
544+
bool HasArg0ExtendedToi32 =
545+
OC == OpGroupAny || OC == OpGroupAll || OC == OpGroupNonUniformAny ||
546+
OC == OpGroupNonUniformAll || OC == OpGroupNonUniformBallot ||
547+
isGroupLogicalOpCode(OC);
548+
549+
// Handle function arguments
550+
if (OC == OpGroupBroadcast) {
551+
Value *VecArg = Mutator.getArg(1);
552+
if (auto *VT = dyn_cast<FixedVectorType>(VecArg->getType())) {
553+
unsigned NumElements = VT->getNumElements();
554+
for (unsigned I = 0; I < NumElements; I++)
555+
Mutator.insertArg(1 + I, Mutator.Builder.CreateExtractElement(
556+
VecArg, Mutator.Builder.getInt32(I)));
557+
Mutator.removeArg(1 + NumElements);
558+
}
559+
} else if (HasArg0ExtendedToi32)
560+
Mutator.mapArg(0, [](IRBuilder<> &Builder, Value *V) {
561+
return Builder.CreateZExt(V, Builder.getInt32Ty());
562+
});
563+
564+
// Handle function return type
565+
if (needsInt32RetTy(OC))
566+
Mutator.changeReturnType(Int32Ty, [](IRBuilder<> &Builder, CallInst *CI) {
563567
// The OpenCL builtin returns a non-zero integer value. Convert to a
564568
// boolean value.
565-
Constant *Zero = ConstantInt::get(CI->getType(), 0);
566-
return new ICmpInst(CI->getNextNode(), CmpInst::ICMP_NE, CI, Zero);
567-
} else
568-
return CI;
569-
};
570-
571-
assert(CI->getCalledFunction() && "Unexpected indirect call");
572-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
573-
SmallVector<AttributeSet, 2> ArgAttrs;
574-
for (int I = (hasGroupOperation(OC) ? 2 : 1);
575-
I < (int)Attrs.getNumAttrSets() - 2; I++)
576-
ArgAttrs.push_back(Attrs.getParamAttrs(I));
577-
Attrs = AttributeList::get(*Ctx, Attrs.getFnAttrs(), Attrs.getRetAttrs(),
578-
ArgAttrs);
579-
mutateCallInstOCL(M, CI, ModifyArguments, ModifyRetTy, &Attrs);
569+
return Builder.CreateICmpNE(CI, Builder.getInt32(0));
570+
});
580571
}
581572

582573
void SPIRVToOCLBase::visitCallSPIRVPipeBuiltin(CallInst *CI, Op OC) {
@@ -586,64 +577,52 @@ void SPIRVToOCLBase::visitCallSPIRVPipeBuiltin(CallInst *CI, Op OC) {
586577
DemangledName = getGroupBuiltinPrefix(CI) + DemangledName;
587578

588579
assert(CI->getCalledFunction() && "Unexpected indirect call");
589-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
590-
mutateCallInstOCL(
591-
M, CI,
592-
[=](CallInst *, std::vector<Value *> &Args) {
593-
if (HasScope)
594-
Args.erase(Args.begin(), Args.begin() + 1);
595-
596-
if (!(OC == OpReadPipe || OC == OpWritePipe ||
597-
OC == OpReservedReadPipe || OC == OpReservedWritePipe ||
598-
OC == OpReadPipeBlockingINTEL || OC == OpWritePipeBlockingINTEL))
599-
return DemangledName;
600-
601-
auto &P = Args[Args.size() - 3];
602-
auto T = P->getType();
603-
assert(isa<PointerType>(T));
604-
auto *NewTy = PointerType::getInt8PtrTy(*Ctx, SPIRAS_Generic);
605-
if (T != NewTy) {
606-
P = CastInst::CreatePointerBitCastOrAddrSpaceCast(P, NewTy, "", CI);
607-
}
608-
return DemangledName;
609-
},
610-
&Attrs);
580+
auto Mutator = mutateCallInst(CI, DemangledName);
581+
if (HasScope)
582+
Mutator.removeArg(0);
583+
if (OC == OpReadPipe || OC == OpWritePipe || OC == OpReservedReadPipe ||
584+
OC == OpReservedWritePipe || OC == OpReadPipeBlockingINTEL ||
585+
OC == OpWritePipeBlockingINTEL) {
586+
Mutator.mapArg(Mutator.arg_size() - 3, [](IRBuilder<> &Builder, Value *P) {
587+
Type *T = P->getType();
588+
assert(isa<PointerType>(T));
589+
auto *NewTy = Builder.getInt8PtrTy(SPIRAS_Generic);
590+
if (T != NewTy) {
591+
P = Builder.CreatePointerBitCastOrAddrSpaceCast(P, NewTy);
592+
}
593+
return std::pair<Value *, Type *>(P, Builder.getInt8Ty());
594+
});
595+
}
611596
}
612597

613598
void SPIRVToOCLBase::visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC) {
614-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
615-
mutateCallInstOCL(
616-
M, CI,
617-
[=](CallInst *, std::vector<Value *> &Args) {
618-
// Moving the first argument to the end.
619-
std::rotate(Args.rbegin(), Args.rend() - 1, Args.rend());
620-
Type *RetType = CI->getType();
621-
if (OC == OpSubgroupImageMediaBlockWriteINTEL) {
622-
assert(Args.size() >= 4 && "Wrong media block write signature");
623-
RetType = Args.at(3)->getType(); // texel type
624-
}
625-
unsigned int BitWidth = RetType->getScalarSizeInBits();
626-
std::string FuncPostfix;
627-
if (BitWidth == 8)
628-
FuncPostfix = "_uc";
629-
else if (BitWidth == 16)
630-
FuncPostfix = "_us";
631-
else if (BitWidth == 32)
632-
FuncPostfix = "_ui";
633-
else
634-
assert(0 && "Unsupported texel type!");
599+
Type *RetType = CI->getType();
600+
if (OC == OpSubgroupImageMediaBlockWriteINTEL) {
601+
assert(CI->arg_size() >= 5 && "Wrong media block write signature");
602+
RetType = CI->getArgOperand(4)->getType(); // texel type
603+
}
604+
unsigned int BitWidth = RetType->getScalarSizeInBits();
605+
std::string FuncPostfix;
606+
if (BitWidth == 8)
607+
FuncPostfix = "_uc";
608+
else if (BitWidth == 16)
609+
FuncPostfix = "_us";
610+
else if (BitWidth == 32)
611+
FuncPostfix = "_ui";
612+
else
613+
assert(0 && "Unsupported texel type!");
635614

636-
if (auto *VecTy = dyn_cast<FixedVectorType>(RetType)) {
637-
unsigned int NumEl = VecTy->getNumElements();
638-
assert((NumEl == 2 || NumEl == 4 || NumEl == 8 || NumEl == 16) &&
639-
"Wrong function type!");
640-
FuncPostfix += std::to_string(NumEl);
641-
}
615+
if (auto *VecTy = dyn_cast<FixedVectorType>(RetType)) {
616+
unsigned int NumEl = VecTy->getNumElements();
617+
assert((NumEl == 2 || NumEl == 4 || NumEl == 8 || NumEl == 16) &&
618+
"Wrong function type!");
619+
FuncPostfix += std::to_string(NumEl);
620+
}
642621

643-
return OCLSPIRVBuiltinMap::rmap(OC) + FuncPostfix;
644-
},
645-
&Attrs);
622+
mutateCallInst(CI, OCLSPIRVBuiltinMap::rmap(OC) + FuncPostfix)
623+
.moveArg(0, CI->arg_size() - 1);
646624
}
625+
647626
void SPIRVToOCLBase::visitCallBuildNDRangeBuiltIn(CallInst *CI, Op OC,
648627
StringRef DemangledName) {
649628
assert(CI->getCalledFunction() && "Unexpected indirect call");
@@ -837,40 +816,34 @@ void SPIRVToOCLBase::visitCallSPIRVImageQueryBuiltIn(CallInst *CI, Op OC) {
837816
}
838817

839818
void SPIRVToOCLBase::visitCallSPIRVSubgroupINTELBuiltIn(CallInst *CI, Op OC) {
840-
assert(CI->getCalledFunction() && "Unexpected indirect call");
841-
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
842-
mutateCallInstOCL(
843-
M, CI,
844-
[=](CallInst *, std::vector<Value *> &Args) {
845-
std::stringstream Name;
846-
Type *DataTy = nullptr;
847-
switch (OC) {
848-
case OpSubgroupBlockReadINTEL:
849-
case OpSubgroupImageBlockReadINTEL:
850-
Name << "intel_sub_group_block_read";
851-
DataTy = CI->getType();
852-
break;
853-
case OpSubgroupBlockWriteINTEL:
854-
Name << "intel_sub_group_block_write";
855-
DataTy = CI->getOperand(1)->getType();
856-
break;
857-
case OpSubgroupImageBlockWriteINTEL:
858-
Name << "intel_sub_group_block_write";
859-
DataTy = CI->getOperand(2)->getType();
860-
break;
861-
default:
862-
return OCLSPIRVBuiltinMap::rmap(OC);
863-
}
864-
assert(DataTy && "Intel subgroup block builtins should have data type");
865-
unsigned VectorNumElements = 1;
866-
if (FixedVectorType *VT = dyn_cast<FixedVectorType>(DataTy))
867-
VectorNumElements = VT->getNumElements();
868-
unsigned ElementBitSize = DataTy->getScalarSizeInBits();
869-
Name << getIntelSubgroupBlockDataPostfix(ElementBitSize,
870-
VectorNumElements);
871-
return Name.str();
872-
},
873-
&Attrs);
819+
std::stringstream Name;
820+
Type *DataTy = nullptr;
821+
switch (OC) {
822+
case OpSubgroupBlockReadINTEL:
823+
case OpSubgroupImageBlockReadINTEL:
824+
Name << "intel_sub_group_block_read";
825+
DataTy = CI->getType();
826+
break;
827+
case OpSubgroupBlockWriteINTEL:
828+
Name << "intel_sub_group_block_write";
829+
DataTy = CI->getOperand(1)->getType();
830+
break;
831+
case OpSubgroupImageBlockWriteINTEL:
832+
Name << "intel_sub_group_block_write";
833+
DataTy = CI->getOperand(2)->getType();
834+
break;
835+
default:
836+
Name << OCLSPIRVBuiltinMap::rmap(OC);
837+
break;
838+
}
839+
if (DataTy) {
840+
unsigned VectorNumElements = 1;
841+
if (FixedVectorType *VT = dyn_cast<FixedVectorType>(DataTy))
842+
VectorNumElements = VT->getNumElements();
843+
unsigned ElementBitSize = DataTy->getScalarSizeInBits();
844+
Name << getIntelSubgroupBlockDataPostfix(ElementBitSize, VectorNumElements);
845+
}
846+
mutateCallInst(CI, Name.str());
874847
}
875848

876849
void SPIRVToOCLBase::visitCallSPIRVAvcINTELEvaluateBuiltIn(CallInst *CI,

test/transcoding/SPV_KHR_uniform_group_instructions/group-instructions.ll

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,9 @@
3939
; CHECK-LLVM: call spir_func i32 @_Z29work_group_reduce_bitwise_andi(i32 0)
4040
; CHECK-LLVM: call spir_func i32 @_Z28work_group_reduce_bitwise_ori(i32 0)
4141
; CHECK-LLVM: call spir_func i32 @_Z29work_group_reduce_bitwise_xori(i32 0)
42-
; CHECK-LLVM: [[INIT_I32_0:%.*]] = zext i1 false to i32
43-
; CHECK-LLVM: call spir_func i32 @_Z29work_group_reduce_logical_andi(i32 [[INIT_I32_0]])
44-
; CHECK-LLVM: [[INIT_I32_1:%.*]] = zext i1 false to i32
45-
; CHECK-LLVM: call spir_func i32 @_Z28work_group_reduce_logical_ori(i32 [[INIT_I32_1]])
46-
; CHECK-LLVM: [[INIT_I32_2:%.*]] = zext i1 false to i32
47-
; CHECK-LLVM: call spir_func i32 @_Z29work_group_reduce_logical_xori(i32 [[INIT_I32_2]])
42+
; CHECK-LLVM: call spir_func i32 @_Z29work_group_reduce_logical_andi(i32 0)
43+
; CHECK-LLVM: call spir_func i32 @_Z28work_group_reduce_logical_ori(i32 0)
44+
; CHECK-LLVM: call spir_func i32 @_Z29work_group_reduce_logical_xori(i32 0)
4845
; CHECK-LLVM: call spir_func i32 @_Z21work_group_reduce_muli(i32 0)
4946
; CHECK-LLVM: call spir_func half @_Z21work_group_reduce_mulDh(half 0xH0000)
5047

0 commit comments

Comments
 (0)