Skip to content

[SPIR-V] Improve general validity of emitted code between passes #119202

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1633,8 +1633,14 @@ static bool generateICarryBorrowInst(const SPIRV::IncomingCall *Call,
}

MachineRegisterInfo *MRI = MIRBuilder.getMRI();
Register ResReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(ResReg, &SPIRV::iIDRegClass);
Register ResReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
if (const TargetRegisterClass *DstRC =
MRI->getRegClassOrNull(Call->Arguments[1])) {
MRI->setRegClass(ResReg, DstRC);
MRI->setType(ResReg, MRI->getType(Call->Arguments[1]));
} else {
MRI->setType(ResReg, LLT::scalar(64));
}
GR->assignSPIRVTypeToVReg(RetType, ResReg, MIRBuilder.getMF());
MIRBuilder.buildInstr(Opcode)
.addDef(ResReg)
Expand Down
231 changes: 129 additions & 102 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {

return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
.addDef(createTypeVReg(MIRBuilder));
Expand Down Expand Up @@ -166,8 +165,11 @@ SPIRVType *SPIRVGlobalRegistry::createOpType(

auto LastInsertedType = LastInsertedTypeMap.find(CurMF);
if (LastInsertedType != LastInsertedTypeMap.end()) {
MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
LastInsertedType->second->getIterator());
auto It = LastInsertedType->second->getIterator();
auto NewMBB = MIRBuilder.getMF().begin();
MIRBuilder.setInsertPt(*NewMBB, It->getNextNode()
? It->getNextNode()->getIterator()
: NewMBB->end());
} else {
MIRBuilder.setInsertPt(*MIRBuilder.getMF().begin(),
MIRBuilder.getMF().begin()->begin());
Expand Down Expand Up @@ -269,24 +271,27 @@ Register SPIRVGlobalRegistry::getOrCreateConstFP(APFloat Val, MachineInstr &I,
// machine instruction, a new constant instruction should be created.
if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
return Res;
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
// In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
if (Val.isPosZero() && ZeroAsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantF))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(
APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
MIB);
}
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
*ST.getRegisterInfo(), *ST.getRegBankInfo());
MachineIRBuilder MIRBuilder(I);
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
MachineInstrBuilder MIB;
// In OpenCL OpConstantNull - Scalar floating point: +0.0 (all bits 0)
if (Val.isPosZero() && ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(
APInt(BitWidth, CI->getValueAPF().bitcastToAPInt().getZExtValue()),
MIB);
}
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(
*MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo());
return MIB;
});
return Res;
}

Expand All @@ -305,21 +310,25 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
// machine instruction, a new constant instruction should be created.
if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
return Res;
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
if (Val || !ZeroAsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
}
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
*ST.getRegisterInfo(), *ST.getRegBankInfo());

MachineIRBuilder MIRBuilder(I);
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
MachineInstrBuilder MIB;
if (Val || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
}
const auto &ST = CurMF->getSubtarget();
constrainSelectedInstRegOperands(
*MIB, *ST.getInstrInfo(), *ST.getRegisterInfo(), *ST.getRegBankInfo());
return MIB;
});
return Res;
}

Expand Down Expand Up @@ -347,21 +356,24 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
MIRBuilder.buildConstant(Res, *ConstInt);
} else {
Register SpvTypeReg = getSPIRVTypeID(SpvType);
MachineInstrBuilder MIB;
if (Val || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(SpvTypeReg);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(SpvTypeReg);
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
*Subtarget.getRegisterInfo(),
*Subtarget.getRegBankInfo());
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
MachineInstrBuilder MIB;
if (Val || !ZeroAsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(SpvTypeReg);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(SpvTypeReg);
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
*Subtarget.getRegisterInfo(),
*Subtarget.getRegBankInfo());
return MIB;
});
}
}
return Res;
Expand All @@ -385,12 +397,14 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, MF);
DT.add(ConstFP, &MF, Res);

MachineInstrBuilder MIB;
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
MachineInstrBuilder MIB;
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
return MIB;
});
}

return Res;
Expand Down Expand Up @@ -439,23 +453,26 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
CurMF->getRegInfo().setRegClass(SpvVecConst, getRegClass(SpvType));
assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
DT.add(CA, CurMF, SpvVecConst);
MachineInstrBuilder MIB;
MachineBasicBlock &BB = *I.getParent();
if (!IsNull) {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
for (unsigned i = 0; i < ElemCnt; ++i)
MIB.addUse(SpvScalConst);
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
*Subtarget.getRegisterInfo(),
*Subtarget.getRegBankInfo());
MachineIRBuilder MIRBuilder(I);
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
MachineInstrBuilder MIB;
if (!IsNull) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
for (unsigned i = 0; i < ElemCnt; ++i)
MIB.addUse(SpvScalConst);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
*Subtarget.getRegisterInfo(),
*Subtarget.getRegBankInfo());
return MIB;
});
return SpvVecConst;
}
return Res;
Expand Down Expand Up @@ -544,17 +561,20 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
if (EmitIR) {
MIRBuilder.buildSplatBuildVector(SpvVecConst, SpvScalConst);
} else {
if (Val) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
for (unsigned i = 0; i < ElemCnt; ++i)
MIB.addUse(SpvScalConst);
} else {
MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
}
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
if (Val) {
auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
for (unsigned i = 0; i < ElemCnt; ++i)
MIB.addUse(SpvScalConst);
return MIB;
} else {
return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(SpvVecConst)
.addUse(getSPIRVTypeID(SpvType));
}
});
}
return SpvVecConst;
}
Expand Down Expand Up @@ -592,9 +612,11 @@ SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::pIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
});
DT.add(CP, CurMF, Res);
}
return Res;
Expand All @@ -614,12 +636,14 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
ResReg.isValid()
? ResReg
: MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass);
auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
.addDef(Sampler)
.addUse(getSPIRVTypeID(SampTy))
.addImm(AddrMode)
.addImm(Param)
.addImm(FilerMode);
auto Res = createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
return MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
.addDef(Sampler)
.addUse(getSPIRVTypeID(SampTy))
.addImm(AddrMode)
.addImm(Param)
.addImm(FilerMode);
});
assert(Res->getOperand(0).isReg());
return Res->getOperand(0).getReg();
}
Expand Down Expand Up @@ -1551,14 +1575,17 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
// create a new type
auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
MIRBuilder.getDebugLoc(),
MIRBuilder.getTII().get(SPIRV::OpTypePointer))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(static_cast<uint32_t>(SC))
.addUse(getSPIRVTypeID(BaseType));
DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
return finishCreatingSPIRVType(LLVMTy, MIB);
return createOpType(MIRBuilder, [&](MachineIRBuilder &MIRBuilder) {
auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
MIRBuilder.getDebugLoc(),
MIRBuilder.getTII().get(SPIRV::OpTypePointer))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addImm(static_cast<uint32_t>(SC))
.addUse(getSPIRVTypeID(BaseType));
DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
finishCreatingSPIRVType(LLVMTy, MIB);
return MIB;
});
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
Expand Down
Loading
Loading