Skip to content

[SPIR-V] Implement builtins for OpIAddCarry/OpISubBorrow and improve/fix type inference #115192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 64 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,14 @@ std::string lookupBuiltinNameHelper(StringRef DemangledCall) {
// Check if the extracted name contains type information between angle
// brackets. If so, the builtin is an instantiated template - needs to have
// the information after angle brackets and return type removed.
if (BuiltinName.find('<') && BuiltinName.back() == '>') {
BuiltinName = BuiltinName.substr(0, BuiltinName.find('<'));
std::size_t Pos1 = BuiltinName.rfind('<');
if (Pos1 != std::string::npos && BuiltinName.back() == '>') {
std::size_t Pos2 = BuiltinName.rfind(' ', Pos1);
if (Pos2 == std::string::npos)
Pos2 = 0;
else
++Pos2;
BuiltinName = BuiltinName.substr(Pos2, Pos1 - Pos2);
BuiltinName = BuiltinName.substr(BuiltinName.find_last_of(' ') + 1);
}

Expand Down Expand Up @@ -461,9 +467,11 @@ static Register buildBuiltinVariableLoad(
SPIRVGlobalRegistry *GR, SPIRV::BuiltIn::BuiltIn BuiltinValue, LLT LLType,
Register Reg = Register(0), bool isConst = true, bool hasLinkageTy = true) {
Register NewRegister =
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
MIRBuilder.getMRI()->setType(NewRegister,
LLT::pointer(0, GR->getPointerSize()));
MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::pIDRegClass);
MIRBuilder.getMRI()->setType(
NewRegister,
LLT::pointer(storageClassToAddressSpace(SPIRV::StorageClass::Function),
GR->getPointerSize()));
SPIRVType *PtrType = GR->getOrCreateSPIRVPointerType(
VariableType, MIRBuilder, SPIRV::StorageClass::Input);
GR->assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF());
Expand Down Expand Up @@ -1556,6 +1564,55 @@ static bool generateWaveInst(const SPIRV::IncomingCall *Call,
/* isConst= */ false, /* hasLinkageTy= */ false);
}

// Lowers a carry/borrow arithmetic builtin. The incoming call is expected to
// have the shape
//   Name(ptr sret([RetType]) %result, Type %operand1, Type %operand2)
// with %result pointing at the storage that receives the builtin's result.
// The emitted sequence is:
//   Res = Opcode RetType Operand1 Operand2
//   OpStore %result Res
// Always returns true; a malformed call is reported as a fatal error.
static bool generateICarryBorrowInst(const SPIRV::IncomingCall *Call,
                                     MachineIRBuilder &MIRBuilder,
                                     SPIRVGlobalRegistry *GR) {
  // Start from the opcode recorded for this builtin; it is switched to the
  // vector flavour below when the operands turn out to be vectors.
  const SPIRV::DemangledBuiltin *Demangled = Call->Builtin;
  unsigned InstOpcode =
      SPIRV::lookupNativeBuiltin(Demangled->Name, Demangled->Set)->Opcode;

  // The first argument is the sret pointer; its pointee is the result type.
  Register ResultPtrReg = Call->Arguments[0];
  SPIRVType *StructTy =
      GR->getPointeeType(GR->getSPIRVTypeForVReg(ResultPtrReg));
  if (!StructTy)
    report_fatal_error("The first parameter must be a pointer");
  if (StructTy->getOpcode() != SPIRV::OpTypeStruct)
    report_fatal_error("Expected struct type result for the arithmetic with "
                       "overflow builtins");

  // Both value operands must carry one and the same known SPIR-V type.
  SPIRVType *LhsTy = GR->getSPIRVTypeForVReg(Call->Arguments[1]);
  SPIRVType *RhsTy = GR->getSPIRVTypeForVReg(Call->Arguments[2]);
  const bool SameOperandType = LhsTy && RhsTy && LhsTy == RhsTy;
  if (!SameOperandType)
    report_fatal_error("Operands must have the same type");

  // Pick the vector variant of the instruction for vector operands; any other
  // opcode is left untouched.
  if (LhsTy->getOpcode() == SPIRV::OpTypeVector) {
    if (InstOpcode == SPIRV::OpIAddCarryS)
      InstOpcode = SPIRV::OpIAddCarryV;
    else if (InstOpcode == SPIRV::OpISubBorrowS)
      InstOpcode = SPIRV::OpISubBorrowV;
  }

  // Create the register holding the struct result and attach its SPIR-V type.
  MachineRegisterInfo *RegInfo = MIRBuilder.getMRI();
  Register ValueReg = RegInfo->createGenericVirtualRegister(LLT::scalar(64));
  RegInfo->setRegClass(ValueReg, &SPIRV::iIDRegClass);
  GR->assignSPIRVTypeToVReg(StructTy, ValueReg, MIRBuilder.getMF());

  // Res = Opcode RetType Operand1 Operand2
  auto ArithMI = MIRBuilder.buildInstr(InstOpcode);
  ArithMI.addDef(ValueReg);
  ArithMI.addUse(GR->getSPIRVTypeID(StructTy));
  ArithMI.addUse(Call->Arguments[1]);
  ArithMI.addUse(Call->Arguments[2]);

  // OpStore %result Res
  auto StoreMI = MIRBuilder.buildInstr(SPIRV::OpStore);
  StoreMI.addUse(ResultPtrReg);
  StoreMI.addUse(ValueReg);
  return true;
}

static bool generateGetQueryInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
Expand Down Expand Up @@ -2511,6 +2568,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateDotOrFMulInst(Call.get(), MIRBuilder, GR);
case SPIRV::Wave:
return generateWaveInst(Call.get(), MIRBuilder, GR);
case SPIRV::ICarryBorrow:
return generateICarryBorrowInst(Call.get(), MIRBuilder, GR);
case SPIRV::GetQuery:
return generateGetQueryInst(Call.get(), MIRBuilder, GR);
case SPIRV::ImageSizeQuery:
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def KernelClock : BuiltinGroup;
def CastToPtr : BuiltinGroup;
def Construct : BuiltinGroup;
def CoopMatr : BuiltinGroup;
def ICarryBorrow : BuiltinGroup;

//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
Expand Down Expand Up @@ -628,6 +629,10 @@ defm : DemangledNativeBuiltin<"barrier", OpenCL_std, Barrier, 1, 3, OpControlBar
defm : DemangledNativeBuiltin<"work_group_barrier", OpenCL_std, Barrier, 1, 3, OpControlBarrier>;
defm : DemangledNativeBuiltin<"__spirv_ControlBarrier", OpenCL_std, Barrier, 3, 3, OpControlBarrier>;

// ICarryBorrow builtin record:
defm : DemangledNativeBuiltin<"__spirv_IAddCarry", OpenCL_std, ICarryBorrow, 3, 3, OpIAddCarryS>;
defm : DemangledNativeBuiltin<"__spirv_ISubBorrow", OpenCL_std, ICarryBorrow, 3, 3, OpISubBorrowS>;

// cl_intel_split_work_group_barrier
defm : DemangledNativeBuiltin<"intel_work_group_barrier_arrive", OpenCL_std, Barrier, 1, 2, OpControlBarrierArriveINTEL>;
defm : DemangledNativeBuiltin<"__spirv_ControlBarrierArriveINTEL", OpenCL_std, Barrier, 3, 3, OpControlBarrierArriveINTEL>;
Expand Down
32 changes: 28 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,36 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
ArgVRegs.push_back(ArgReg);
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
if (!SpvType) {
SpvType = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
Type *ArgTy = nullptr;
if (auto *PtrArgTy = dyn_cast<PointerType>(Arg.Ty)) {
// If Arg.Ty is an untyped pointer (i.e., ptr [addrspace(...)]) and we
// don't have access to original value in LLVM IR or info about
// deduced pointee type, then we should wait with setting the type for
// the virtual register until pre-legalizer step when we access
// @llvm.spv.assign.ptr.type.p...(...)'s info.
if (Arg.OrigValue)
if (Type *ElemTy = GR->findDeducedElementType(Arg.OrigValue))
ArgTy =
TypedPointerType::get(ElemTy, PtrArgTy->getAddressSpace());
} else {
ArgTy = Arg.Ty;
}
if (ArgTy) {
SpvType = GR->getOrCreateSPIRVType(ArgTy, MIRBuilder);
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
}
}
if (!MRI->getRegClassOrNull(ArgReg)) {
MRI->setRegClass(ArgReg, GR->getRegClass(SpvType));
MRI->setType(ArgReg, GR->getRegType(SpvType));
// Either we have SpvType created, or Arg.Ty is an untyped pointer and
// we know its virtual register's class and type even if we don't know
// pointee type.
MRI->setRegClass(ArgReg, SpvType ? GR->getRegClass(SpvType)
: &SPIRV::pIDRegClass);
MRI->setType(
ArgReg,
SpvType ? GR->getRegType(SpvType)
: LLT::pointer(cast<PointerType>(Arg.Ty)->getAddressSpace(),
GR->getPointerSize()));
}
}
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
Expand Down
50 changes: 37 additions & 13 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class SPIRVEmitIntrinsics
SPIRV::InstructionSet::InstructionSet InstrSet;

// a register of Instructions that don't have a complete type definition
SmallPtrSet<Value *, 8> UncompleteTypeInfo;
SmallVector<Instruction *> PostprocessWorklist;
DenseMap<Value *, unsigned> UncompleteTypeInfo;
SmallVector<Value *> PostprocessWorklist;

// well known result types of builtins
enum WellKnownTypes { Event };
Expand Down Expand Up @@ -147,6 +147,7 @@ class SPIRVEmitIntrinsics
std::unordered_set<Function *> &FVisited);
void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
CallInst *AssignCI);
void replaceAllUsesWith(Value *Src, Value *Dest, bool DeleteOld = true);

bool runOnFunction(Function &F);
bool postprocessTypes();
Expand Down Expand Up @@ -272,6 +273,27 @@ static inline void reportFatalOnTokenType(const Instruction *I) {
false);
}

// Replaces all uses of Src with Dest and mirrors the replacement in the
// bookkeeping kept for Src: the registry's deduced element type and
// assign-ptr-type records, plus this pass' incomplete-type map and
// post-processing worklist.
// When DeleteOld is true the records are moved from Src to Dest (Dest takes
// over Src's worklist slot); otherwise Src's records are kept and Dest gets
// records of its own (appended at the end of the worklist).
void SPIRVEmitIntrinsics::replaceAllUsesWith(Value *Src, Value *Dest,
                                             bool DeleteOld) {
  Src->replaceAllUsesWith(Dest);

  // Keep the registry's per-value type records in sync.
  GR->updateIfExistDeducedElementType(Src, Dest, DeleteOld);
  GR->updateIfExistAssignPtrTypeInstr(Src, Dest, DeleteOld);

  // Nothing more to do unless Src is tracked as having an incomplete type.
  auto SrcRecord = UncompleteTypeInfo.find(Src);
  if (SrcRecord == UncompleteTypeInfo.end())
    return;

  if (!DeleteOld) {
    // Src stays tracked; register Dest as an additional worklist entry.
    UncompleteTypeInfo[Dest] = PostprocessWorklist.size();
    PostprocessWorklist.push_back(Dest);
    return;
  }

  // Hand Src's worklist slot over to Dest. The slot index is copied out
  // before the map entry is erased.
  unsigned Slot = SrcRecord->second;
  UncompleteTypeInfo.erase(Src);
  UncompleteTypeInfo[Dest] = Slot;
  PostprocessWorklist[Slot] = Dest;
}

static bool IsKernelArgInt8(Function *F, StoreInst *SI) {
return SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
isPointerTy(SI->getValueOperand()->getType()) &&
Expand Down Expand Up @@ -434,7 +456,7 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
if (!UnknownElemTypeI8)
return;
if (auto *I = dyn_cast<Instruction>(Op)) {
UncompleteTypeInfo.insert(I);
UncompleteTypeInfo[I] = PostprocessWorklist.size();
PostprocessWorklist.push_back(I);
}
}
Expand Down Expand Up @@ -640,7 +662,7 @@ Type *SPIRVEmitIntrinsics::deduceElementType(Value *I, bool UnknownElemTypeI8) {
if (!UnknownElemTypeI8)
return nullptr;
if (auto *Instr = dyn_cast<Instruction>(I)) {
UncompleteTypeInfo.insert(Instr);
UncompleteTypeInfo[Instr] = PostprocessWorklist.size();
PostprocessWorklist.push_back(Instr);
}
return IntegerType::getInt8Ty(I->getContext());
Expand Down Expand Up @@ -1062,7 +1084,7 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
{I.getOperand(0)->getType()}, {Args});
// remove switch to avoid its unneeded and undesirable unwrap into branches
// and conditions
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
// insert artificial and temporary instruction to preserve valid CFG,
// it will be removed after IR translation pass
Expand All @@ -1084,7 +1106,7 @@ Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
for (auto &Op : I.operands())
Args.push_back(Op);
auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
return NewI;
}
Expand All @@ -1099,7 +1121,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
// such bitcasts do not provide sufficient information, should be just skipped
// here, and handled in insertPtrCastOrAssignTypeInstr.
if (isPointerTy(I.getType())) {
I.replaceAllUsesWith(Source);
replaceAllUsesWith(&I, Source);
I.eraseFromParent();
return nullptr;
}
Expand All @@ -1108,7 +1130,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
SmallVector<Value *> Args(I.op_begin(), I.op_end());
auto *NewI = B.CreateIntrinsic(Intrinsic::spv_bitcast, {Types}, {Args});
std::string InstName = I.hasName() ? I.getName().str() : "";
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
NewI->setName(InstName);
return NewI;
Expand Down Expand Up @@ -1219,6 +1241,8 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
SmallVector<Value *, 2> Args = {Pointer, VMD, B.getInt32(AddressSpace)};
auto *PtrCastI = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
I->setOperand(OperandToReplace, PtrCastI);
// We need to set up a pointee type for the newly created spv_ptrcast.
buildAssignPtr(B, ExpectedElementType, PtrCastI);
}

void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
Expand Down Expand Up @@ -1331,7 +1355,7 @@ Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) {
SmallVector<Value *> Args(I.op_begin(), I.op_end());
auto *NewI = B.CreateIntrinsic(Intrinsic::spv_insertelt, {Types}, {Args});
std::string InstName = I.hasName() ? I.getName().str() : "";
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
NewI->setName(InstName);
return NewI;
Expand All @@ -1346,7 +1370,7 @@ SPIRVEmitIntrinsics::visitExtractElementInst(ExtractElementInst &I) {
SmallVector<Value *, 2> Args = {I.getVectorOperand(), I.getIndexOperand()};
auto *NewI = B.CreateIntrinsic(Intrinsic::spv_extractelt, {Types}, {Args});
std::string InstName = I.hasName() ? I.getName().str() : "";
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
NewI->setName(InstName);
return NewI;
Expand Down Expand Up @@ -1382,7 +1406,7 @@ Instruction *SPIRVEmitIntrinsics::visitExtractValueInst(ExtractValueInst &I) {
Args.push_back(B.getInt32(Op));
auto *NewI =
B.CreateIntrinsic(Intrinsic::spv_extractv, {I.getType()}, {Args});
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
return NewI;
}
Expand Down Expand Up @@ -1443,7 +1467,7 @@ Instruction *SPIRVEmitIntrinsics::visitAllocaInst(AllocaInst &I) {
{PtrTy, ArraySize->getType()}, {ArraySize})
: B.CreateIntrinsic(Intrinsic::spv_alloca, {PtrTy}, {});
std::string InstName = I.hasName() ? I.getName().str() : "";
I.replaceAllUsesWith(NewI);
replaceAllUsesWith(&I, NewI);
I.eraseFromParent();
NewI->setName(InstName);
return NewI;
Expand Down Expand Up @@ -1613,7 +1637,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I,
auto *NewOp =
buildIntrWithMD(Intrinsic::spv_track_constant,
{II->getType(), II->getType()}, t->second, I, {}, B);
I->replaceAllUsesWith(NewOp);
replaceAllUsesWith(I, NewOp, false);
NewOp->setArgOperand(0, I);
}
bool IsPhi = isa<PHINode>(I), BPrepared = false;
Expand Down
18 changes: 18 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ class SPIRVGlobalRegistry {
auto It = AssignPtrTypeInstr.find(Val);
return It == AssignPtrTypeInstr.end() ? nullptr : It->second;
}
// - If a record exists for OldVal, store the same instruction under NewVal as
//   well; the OldVal record is dropped first when DeleteOld is set. Does
//   nothing when no record is found for OldVal.
void updateIfExistAssignPtrTypeInstr(Value *OldVal, Value *NewVal,
                                     bool DeleteOld) {
  CallInst *AssignCI = findAssignPtrTypeInstr(OldVal);
  if (!AssignCI)
    return;
  if (DeleteOld)
    AssignPtrTypeInstr.erase(OldVal);
  AssignPtrTypeInstr[NewVal] = AssignCI;
}

// A registry of mutated values
// (see `SPIRVPrepareFunctions::removeAggregateTypesFromSignature()`):
Expand Down Expand Up @@ -214,6 +223,15 @@ class SPIRVGlobalRegistry {
auto It = DeducedElTys.find(Val);
return It == DeducedElTys.end() ? nullptr : It->second;
}
// - If a deduced element type is recorded for OldVal, record the same type
//   for NewVal as well; the OldVal record is dropped first when DeleteOld is
//   set. Does nothing when no type is recorded for OldVal.
void updateIfExistDeducedElementType(Value *OldVal, Value *NewVal,
                                     bool DeleteOld) {
  Type *ElemTy = findDeducedElementType(OldVal);
  if (!ElemTy)
    return;
  if (DeleteOld)
    DeducedElTys.erase(OldVal);
  DeducedElTys[NewVal] = ElemTy;
}
// - Add a record to the map of deduced composite types.
void addDeducedCompositeType(Value *Val, Type *Ty) {
DeducedNestedTys[Val] = Ty;
Expand Down
18 changes: 15 additions & 3 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,17 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
validateForwardCalls(STI, MRI, GR, MI);
break;

// ensure that LLVM IR add/sub instructions result in logical SPIR-V
// instructions when applied to bool type
case SPIRV::OpIAddS:
case SPIRV::OpIAddV:
case SPIRV::OpISubS:
case SPIRV::OpISubV:
if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(),
SPIRV::OpTypeBool))
MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual));
break;

// ensure that LLVM IR bitwise instructions result in logical SPIR-V
// instructions when applied to bool type
case SPIRV::OpBitwiseOrS:
Expand Down Expand Up @@ -473,8 +484,11 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
MI.getOperand(2).getImm() != SPIRV::InstructionSet::OpenCL_std)
continue;
switch (MI.getOperand(3).getImm()) {
case SPIRV::OpenCLExtInst::frexp:
case SPIRV::OpenCLExtInst::lgamma_r:
case SPIRV::OpenCLExtInst::remquo: {
// The last operand must be of a pointer to the return type.
// The last operand must be of a pointer to i32 or vector of i32
// values.
MachineIRBuilder MIB(MI);
SPIRVType *Int32Type = GR.getOrCreateSPIRVIntegerType(32, MIB);
SPIRVType *RetType = MRI->getVRegDef(MI.getOperand(1).getReg());
Expand All @@ -487,8 +501,6 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
Int32Type, RetType->getOperand(2).getImm(), MIB));
} break;
case SPIRV::OpenCLExtInst::fract:
case SPIRV::OpenCLExtInst::frexp:
case SPIRV::OpenCLExtInst::lgamma_r:
case SPIRV::OpenCLExtInst::modf:
case SPIRV::OpenCLExtInst::sincos:
// The last operand must be of a pointer to the base type represented
Expand Down
Loading
Loading