Skip to content

[SPIR-V] Re-implement switch and improve validation of forward calls #87823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ using namespace llvm;

namespace {
class SPIRVAsmPrinter : public AsmPrinter {
unsigned NLabels = 0;

public:
explicit SPIRVAsmPrinter(TargetMachine &TM,
std::unique_ptr<MCStreamer> Streamer)
Expand Down Expand Up @@ -109,10 +111,9 @@ void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
uint32_t DecSPIRVVersion = ST->getSPIRVVersion();
uint32_t Major = DecSPIRVVersion / 10;
uint32_t Minor = DecSPIRVVersion - Major * 10;
// TODO: calculate Bound more carefully from maximum used register number,
// accounting for generated OpLabels and other related instructions if
// needed.
unsigned Bound = 2 * (ST->getBound() + 1);
// Bound is an approximation that accounts for the maximum used register
// number and the number of generated OpLabels.
unsigned Bound = 2 * (ST->getBound() + 1) + NLabels;
bool FlagToRestore = OutStreamer->getUseAssemblerInfoForParsing();
OutStreamer->setUseAssemblerInfoForParsing(true);
if (MCAssembler *Asm = OutStreamer->getAssemblerPtr())
Expand Down Expand Up @@ -158,6 +159,7 @@ void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
LabelInst.setOpcode(SPIRV::OpLabel);
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
outputMCInst(LabelInst);
++NLabels;
}

void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
Expand Down
35 changes: 28 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,15 +460,36 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
}

Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
IRBuilder<> B(I.getParent());
BasicBlock *ParentBB = I.getParent();
IRBuilder<> B(ParentBB);
B.SetInsertPoint(&I);
SmallVector<Value *, 4> Args;
for (auto &Op : I.operands())
if (Op.get()->getType()->isSized())
SmallVector<BasicBlock *> BBCases;
for (auto &Op : I.operands()) {
if (Op.get()->getType()->isSized()) {
Args.push_back(Op);
B.SetInsertPoint(&I);
B.CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()},
{Args});
return &I;
} else if (BasicBlock *BB = dyn_cast<BasicBlock>(Op.get())) {
BBCases.push_back(BB);
Args.push_back(BlockAddress::get(BB->getParent(), BB));
} else {
report_fatal_error("Unexpected switch operand");
}
}
CallInst *NewI = B.CreateIntrinsic(Intrinsic::spv_switch,
{I.getOperand(0)->getType()}, {Args});
// Remove the switch to avoid its unneeded and undesirable unwrapping into
// branches and conditions.
I.replaceAllUsesWith(NewI);
I.eraseFromParent();
// Insert an artificial, temporary instruction to preserve a valid CFG;
// it will be removed after the IR translation pass.
B.SetInsertPoint(ParentBB);
IndirectBrInst *BrI = B.CreateIndirectBr(
Constant::getNullValue(PointerType::getUnqual(ParentBB->getContext())),
BBCases.size());
for (BasicBlock *BBCase : BBCases)
BrI->addDestination(BBCase);
return BrI;
}

Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
Expand Down
7 changes: 6 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,12 @@ class SPIRVGlobalRegistry {
// Return the VReg holding the result of the given OpTypeXXX instruction.
Register getSPIRVTypeID(const SPIRVType *SpirvType) const;

void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; }
// Make MF the current machine function and hand back the machine function
// that was current before this call, so a caller can switch back to it later.
MachineFunction *setCurrentFunc(MachineFunction &MF) {
  MachineFunction *const Previous = CurMF;
  CurMF = &MF;
  return Previous;
}

// Whether the given VReg has an OpTypeXXX instruction mapped to it with the
// given opcode (e.g. OpTypeFloat).
Expand Down
17 changes: 13 additions & 4 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,15 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
: nullptr;
if (DefElemType) {
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
// Switch GR context to the call site instead of the (default) definition
// side
GR.setCurrentFunc(*FunCall.getParent()->getParent());
// validatePtrTypes() works in the context of the call site.
// When we process historical records about forward calls,
// we need to switch context to the (forward) call site and
// then restore it back to the current machine function.
MachineFunction *CurMF =
GR.setCurrentFunc(*FunCall.getParent()->getParent());
validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
DefElemTy);
GR.setCurrentFunc(*FunDef->getParent()->getParent());
GR.setCurrentFunc(*CurMF);
}
}
}
Expand Down Expand Up @@ -215,6 +218,11 @@ void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
// TODO: the logic of inserting additional bitcast's is to be moved
// to pre-IRTranslation passes eventually
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
// finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
// We'd like to avoid the needless second processing pass.
if (ProcessedMF.find(&MF) != ProcessedMF.end())
return;

MachineRegisterInfo *MRI = &MF.getRegInfo();
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
GR.setCurrentFunc(MF);
Expand Down Expand Up @@ -302,5 +310,6 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
}
}
}
ProcessedMF.insert(&MF);
TargetLowering::finalizeLowering(MF);
}
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@

#include "SPIRVGlobalRegistry.h"
#include "llvm/CodeGen/TargetLowering.h"
#include <set>

namespace llvm {
class SPIRVSubtarget;

class SPIRVTargetLowering : public TargetLowering {
const SPIRVSubtarget &STI;

// Record of already processed machine functions
mutable std::set<const MachineFunction *> ProcessedMF;

public:
explicit SPIRVTargetLowering(const TargetMachine &TM,
const SPIRVSubtarget &ST)
Expand Down
223 changes: 56 additions & 167 deletions llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -438,186 +438,75 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
}
}

// Find basic blocks of the switch and replace registers in spv_switch() by its
// MBB equivalent.
static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
MachineIRBuilder MIB) {
// Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
// each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
// + G_BR triples. A switch with two cases may be transformed to this MIR
// sequence:
//
// intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
// %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
// G_BRCOND %Dst0, %bb.2
// G_BR %bb.5
// bb.5.entry:
// %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
// G_BRCOND %Dst1, %bb.3
// G_BR %bb.4
// bb.2.sw.bb:
// ...
// bb.3.sw.bb1:
// ...
// bb.4.sw.epilog:
// ...
//
// Sometimes (in case of range-compare switches), additional G_SUBs
// instructions are inserted before G_ICMPs. Those need to be additionally
// processed.
//
// This function modifies spv_switch call's operands to include destination
// MBBs (default and for each constant value).
//
// At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND +
// G_BR sequences.

MachineRegisterInfo &MRI = MF.getRegInfo();

// Collect spv_switches and G_ICMPs across all MBBs in MF.
std::vector<MachineInstr *> RelevantInsts;

// Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
// After updating spv_switches, the instructions can be removed.
std::vector<MachineInstr *> PostUpdateArtifacts;

// Temporary set of compare registers. G_SUBs and G_ICMPs relating to
// spv_switch use these registers.
DenseSet<Register> CompareRegs;
DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
Switches;
for (MachineBasicBlock &MBB : MF) {
MachineRegisterInfo &MRI = MF.getRegInfo();
BB2MBB[MBB.getBasicBlock()] = &MBB;
for (MachineInstr &MI : MBB) {
if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
continue;
// Calls to spv_switch intrinsics representing IR switches.
if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
assert(MI.getOperand(1).isReg());
CompareRegs.insert(MI.getOperand(1).getReg());
RelevantInsts.push_back(&MI);
}

// G_SUBs coming from range-compare switch lowering. G_SUBs are found
// after spv_switch but before G_ICMP.
if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
CompareRegs.contains(MI.getOperand(1).getReg())) {
assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
Register Dst = MI.getOperand(0).getReg();
CompareRegs.insert(Dst);
PostUpdateArtifacts.push_back(&MI);
}

// G_ICMPs relating to switches.
if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
CompareRegs.contains(MI.getOperand(2).getReg())) {
Register Dst = MI.getOperand(0).getReg();
RelevantInsts.push_back(&MI);
PostUpdateArtifacts.push_back(&MI);
MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
assert(CBr->getOpcode() == SPIRV::G_BRCOND);
PostUpdateArtifacts.push_back(CBr);
MachineInstr *Br = CBr->getNextNode();
assert(Br->getOpcode() == SPIRV::G_BR);
PostUpdateArtifacts.push_back(Br);
SmallVector<MachineInstr *, 8> NewOps;
for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
Register Reg = MI.getOperand(i).getReg();
if (i % 2 == 1) {
MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
NewOps.push_back(ConstInstr);
} else {
MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
assert(BuildMBB &&
BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
BuildMBB->getOperand(1).isBlockAddress() &&
BuildMBB->getOperand(1).getBlockAddress());
NewOps.push_back(BuildMBB);
}
}
Switches.push_back(std::make_pair(&MI, NewOps));
}
}

// Update each spv_switch with destination MBBs.
for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
continue;

// Currently considered spv_switch.
MachineInstr *Switch = *i;
// Set the first successor as default MBB to support empty switches.
MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
// Container for mapping values to MMBs.
SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;

// Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
// spv_switch (i) and break at any spv_switch with the same compare
// register (indicating we are back at the same scope).
Register CompareReg = Switch->getOperand(1).getReg();
for (auto j = i + 1; j != RelevantInsts.end(); j++) {
if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
(*j)->getOperand(1).getReg() == CompareReg)
break;

if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
(*j)->getOperand(2).getReg() == CompareReg))
continue;

MachineInstr *ICMP = *j;
Register Dst = ICMP->getOperand(0).getReg();
MachineOperand &PredOp = ICMP->getOperand(1);
const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
(void)CC;
assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();

// Map switch case Value to target MBB.
ValuesToMBBs[Value] = MBB;

// Add target MBB as successor to the switch's MBB.
Switch->getParent()->addSuccessor(MBB);

// The next MI is always G_BR to either the next case or the default.
MachineInstr *NextMI = CBr->getNextNode();
assert(NextMI->getOpcode() == SPIRV::G_BR &&
NextMI->getOperand(0).isMBB());
MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
// Default MBB does not begin with G_ICMP using spv_switch compare
// register.
if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
(NextMBB->front().getOperand(2).isReg() &&
NextMBB->front().getOperand(2).getReg() != CompareReg)) {
// Set default MBB and add it as successor to the switch's MBB.
DefaultMBB = NextMBB;
Switch->getParent()->addSuccessor(DefaultMBB);
SmallPtrSet<MachineInstr *, 8> ToEraseMI;
for (auto &SwIt : Switches) {
MachineInstr &MI = *SwIt.first;
SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
SmallVector<MachineOperand, 8> NewOps;
for (unsigned i = 0; i < Ins.size(); ++i) {
if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
BasicBlock *CaseBB =
Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
auto It = BB2MBB.find(CaseBB);
if (It == BB2MBB.end())
report_fatal_error("cannot find a machine basic block by a basic "
"block in a switch statement");
NewOps.push_back(MachineOperand::CreateMBB(It->second));
MI.getParent()->addSuccessor(It->second);
ToEraseMI.insert(Ins[i]);
} else {
NewOps.push_back(
MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
}
}

// Modify considered spv_switch operands using collected Values and
// MBBs.
SmallVector<const ConstantInt *, 3> Values;
SmallVector<MachineBasicBlock *, 3> MBBs;
for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
Register CReg = Switch->getOperand(k).getReg();
uint64_t Val = getIConstVal(CReg, &MRI);
MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
if (!ValuesToMBBs[Val])
continue;

Values.push_back(ConstInstr->getOperand(1).getCImm());
MBBs.push_back(ValuesToMBBs[Val]);
}

for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
Switch->removeOperand(k);

Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
for (unsigned k = 0; k < Values.size(); k++) {
Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
}
}

for (MachineInstr *MI : PostUpdateArtifacts) {
MachineBasicBlock *ParentMBB = MI->getParent();
MI->eraseFromParent();
// If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It
// can be safely assumed, there are no breaks or phis directing into this
// MBB. However, we need to remove this MBB from the CFG graph. MBBs must be
// erased top-down.
if (ParentMBB->empty()) {
while (!ParentMBB->pred_empty())
(*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);

while (!ParentMBB->succ_empty())
ParentMBB->removeSuccessor(ParentMBB->succ_begin());

ParentMBB->eraseFromParent();
for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
MI.removeOperand(i);
for (auto &MO : NewOps)
MI.addOperand(MO);
if (MachineInstr *Next = MI.getNextNode()) {
if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
ToEraseMI.insert(Next);
Next = MI.getNextNode();
}
if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
ToEraseMI.insert(Next);
}
}
for (MachineInstr *BlockAddrI : ToEraseMI)
BlockAddrI->eraseFromParent();
}

static bool isImplicitFallthrough(MachineBasicBlock &MBB) {
Expand Down
Loading