Skip to content

Commit 23b058c

Browse files
[SPIR-V] Re-implement switch and improve validation of forward calls (#87823)
This PR fixes issue #87763 and preserves valid CFG in cases when previous scheme failed to generate valid code for a switch statement. The PR hardens one existing test case and adds one more test case as a validation of a new switch generation. Tests are passing spirv-val now. This PR also improves validation of forward calls.
1 parent 3f71d29 commit 23b058c

File tree

8 files changed

+189
-185
lines changed

8 files changed

+189
-185
lines changed

llvm/lib/Target/SPIRV/SPIRVAsmPrinter.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ using namespace llvm;
4343

4444
namespace {
4545
class SPIRVAsmPrinter : public AsmPrinter {
46+
unsigned NLabels = 0;
47+
4648
public:
4749
explicit SPIRVAsmPrinter(TargetMachine &TM,
4850
std::unique_ptr<MCStreamer> Streamer)
@@ -109,10 +111,9 @@ void SPIRVAsmPrinter::emitEndOfAsmFile(Module &M) {
109111
uint32_t DecSPIRVVersion = ST->getSPIRVVersion();
110112
uint32_t Major = DecSPIRVVersion / 10;
111113
uint32_t Minor = DecSPIRVVersion - Major * 10;
112-
// TODO: calculate Bound more carefully from maximum used register number,
113-
// accounting for generated OpLabels and other related instructions if
114-
// needed.
115-
unsigned Bound = 2 * (ST->getBound() + 1);
114+
// Bound is an approximation that accounts for the maximum used register
115+
// number and number of generated OpLabels
116+
unsigned Bound = 2 * (ST->getBound() + 1) + NLabels;
116117
bool FlagToRestore = OutStreamer->getUseAssemblerInfoForParsing();
117118
OutStreamer->setUseAssemblerInfoForParsing(true);
118119
if (MCAssembler *Asm = OutStreamer->getAssemblerPtr())
@@ -158,6 +159,7 @@ void SPIRVAsmPrinter::emitOpLabel(const MachineBasicBlock &MBB) {
158159
LabelInst.setOpcode(SPIRV::OpLabel);
159160
LabelInst.addOperand(MCOperand::createReg(MAI->getOrCreateMBBRegister(MBB)));
160161
outputMCInst(LabelInst);
162+
++NLabels;
161163
}
162164

163165
void SPIRVAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -460,15 +460,36 @@ void SPIRVEmitIntrinsics::preprocessCompositeConstants(IRBuilder<> &B) {
460460
}
461461

462462
Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
463-
IRBuilder<> B(I.getParent());
463+
BasicBlock *ParentBB = I.getParent();
464+
IRBuilder<> B(ParentBB);
465+
B.SetInsertPoint(&I);
464466
SmallVector<Value *, 4> Args;
465-
for (auto &Op : I.operands())
466-
if (Op.get()->getType()->isSized())
467+
SmallVector<BasicBlock *> BBCases;
468+
for (auto &Op : I.operands()) {
469+
if (Op.get()->getType()->isSized()) {
467470
Args.push_back(Op);
468-
B.SetInsertPoint(&I);
469-
B.CreateIntrinsic(Intrinsic::spv_switch, {I.getOperand(0)->getType()},
470-
{Args});
471-
return &I;
471+
} else if (BasicBlock *BB = dyn_cast<BasicBlock>(Op.get())) {
472+
BBCases.push_back(BB);
473+
Args.push_back(BlockAddress::get(BB->getParent(), BB));
474+
} else {
475+
report_fatal_error("Unexpected switch operand");
476+
}
477+
}
478+
CallInst *NewI = B.CreateIntrinsic(Intrinsic::spv_switch,
479+
{I.getOperand(0)->getType()}, {Args});
480+
// remove switch to avoid its unneeded and undesirable unwrap into branches
481+
// and conditions
482+
I.replaceAllUsesWith(NewI);
483+
I.eraseFromParent();
484+
// insert artificial and temporary instruction to preserve valid CFG,
485+
// it will be removed after IR translation pass
486+
B.SetInsertPoint(ParentBB);
487+
IndirectBrInst *BrI = B.CreateIndirectBr(
488+
Constant::getNullValue(PointerType::getUnqual(ParentBB->getContext())),
489+
BBCases.size());
490+
for (BasicBlock *BBCase : BBCases)
491+
BrI->addDestination(BBCase);
492+
return BrI;
472493
}
473494

474495
Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,12 @@ class SPIRVGlobalRegistry {
284284
// Return the VReg holding the result of the given OpTypeXXX instruction.
285285
Register getSPIRVTypeID(const SPIRVType *SpirvType) const;
286286

287-
void setCurrentFunc(MachineFunction &MF) { CurMF = &MF; }
287+
// Return previous value of the current machine function
288+
MachineFunction *setCurrentFunc(MachineFunction &MF) {
289+
MachineFunction *Ret = CurMF;
290+
CurMF = &MF;
291+
return Ret;
292+
}
288293

289294
// Whether the given VReg has an OpTypeXXX instruction mapped to it with the
290295
// given opcode (e.g. OpTypeFloat).

llvm/lib/Target/SPIRV/SPIRVISelLowering.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,15 @@ void validateFunCallMachineDef(const SPIRVSubtarget &STI,
160160
: nullptr;
161161
if (DefElemType) {
162162
const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType);
163-
// Switch GR context to the call site instead of the (default) definition
164-
// side
165-
GR.setCurrentFunc(*FunCall.getParent()->getParent());
163+
// validatePtrTypes() works in the context if the call site
164+
// When we process historical records about forward calls
165+
// we need to switch context to the (forward) call site and
166+
// then restore it back to the current machine function.
167+
MachineFunction *CurMF =
168+
GR.setCurrentFunc(*FunCall.getParent()->getParent());
166169
validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType,
167170
DefElemTy);
168-
GR.setCurrentFunc(*FunDef->getParent()->getParent());
171+
GR.setCurrentFunc(*CurMF);
169172
}
170173
}
171174
}
@@ -215,6 +218,11 @@ void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI,
215218
// TODO: the logic of inserting additional bitcast's is to be moved
216219
// to pre-IRTranslation passes eventually
217220
void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
221+
// finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp)
222+
// We'd like to avoid the needless second processing pass.
223+
if (ProcessedMF.find(&MF) != ProcessedMF.end())
224+
return;
225+
218226
MachineRegisterInfo *MRI = &MF.getRegInfo();
219227
SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry();
220228
GR.setCurrentFunc(MF);
@@ -302,5 +310,6 @@ void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const {
302310
}
303311
}
304312
}
313+
ProcessedMF.insert(&MF);
305314
TargetLowering::finalizeLowering(MF);
306315
}

llvm/lib/Target/SPIRV/SPIRVISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,17 @@
1616

1717
#include "SPIRVGlobalRegistry.h"
1818
#include "llvm/CodeGen/TargetLowering.h"
19+
#include <set>
1920

2021
namespace llvm {
2122
class SPIRVSubtarget;
2223

2324
class SPIRVTargetLowering : public TargetLowering {
2425
const SPIRVSubtarget &STI;
2526

27+
// Record of already processed machine functions
28+
mutable std::set<const MachineFunction *> ProcessedMF;
29+
2630
public:
2731
explicit SPIRVTargetLowering(const TargetMachine &TM,
2832
const SPIRVSubtarget &ST)

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 56 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -438,186 +438,75 @@ static void processInstrsWithTypeFolding(MachineFunction &MF,
438438
}
439439
}
440440

441+
// Find basic blocks of the switch and replace registers in spv_switch() by its
442+
// MBB equivalent.
441443
static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
442444
MachineIRBuilder MIB) {
443-
// Before IRTranslator pass, calls to spv_switch intrinsic are inserted before
444-
// each switch instruction. IRTranslator lowers switches to G_ICMP + G_BRCOND
445-
// + G_BR triples. A switch with two cases may be transformed to this MIR
446-
// sequence:
447-
//
448-
// intrinsic(@llvm.spv.switch), %CmpReg, %Const0, %Const1
449-
// %Dst0 = G_ICMP intpred(eq), %CmpReg, %Const0
450-
// G_BRCOND %Dst0, %bb.2
451-
// G_BR %bb.5
452-
// bb.5.entry:
453-
// %Dst1 = G_ICMP intpred(eq), %CmpReg, %Const1
454-
// G_BRCOND %Dst1, %bb.3
455-
// G_BR %bb.4
456-
// bb.2.sw.bb:
457-
// ...
458-
// bb.3.sw.bb1:
459-
// ...
460-
// bb.4.sw.epilog:
461-
// ...
462-
//
463-
// Sometimes (in case of range-compare switches), additional G_SUBs
464-
// instructions are inserted before G_ICMPs. Those need to be additionally
465-
// processed.
466-
//
467-
// This function modifies spv_switch call's operands to include destination
468-
// MBBs (default and for each constant value).
469-
//
470-
// At the end, the function removes redundant [G_SUB] + G_ICMP + G_BRCOND +
471-
// G_BR sequences.
472-
473-
MachineRegisterInfo &MRI = MF.getRegInfo();
474-
475-
// Collect spv_switches and G_ICMPs across all MBBs in MF.
476-
std::vector<MachineInstr *> RelevantInsts;
477-
478-
// Collect redundant MIs from [G_SUB] + G_ICMP + G_BRCOND + G_BR sequences.
479-
// After updating spv_switches, the instructions can be removed.
480-
std::vector<MachineInstr *> PostUpdateArtifacts;
481-
482-
// Temporary set of compare registers. G_SUBs and G_ICMPs relating to
483-
// spv_switch use these registers.
484-
DenseSet<Register> CompareRegs;
445+
DenseMap<const BasicBlock *, MachineBasicBlock *> BB2MBB;
446+
SmallVector<std::pair<MachineInstr *, SmallVector<MachineInstr *, 8>>>
447+
Switches;
485448
for (MachineBasicBlock &MBB : MF) {
449+
MachineRegisterInfo &MRI = MF.getRegInfo();
450+
BB2MBB[MBB.getBasicBlock()] = &MBB;
486451
for (MachineInstr &MI : MBB) {
452+
if (!isSpvIntrinsic(MI, Intrinsic::spv_switch))
453+
continue;
487454
// Calls to spv_switch intrinsics representing IR switches.
488-
if (isSpvIntrinsic(MI, Intrinsic::spv_switch)) {
489-
assert(MI.getOperand(1).isReg());
490-
CompareRegs.insert(MI.getOperand(1).getReg());
491-
RelevantInsts.push_back(&MI);
492-
}
493-
494-
// G_SUBs coming from range-compare switch lowering. G_SUBs are found
495-
// after spv_switch but before G_ICMP.
496-
if (MI.getOpcode() == TargetOpcode::G_SUB && MI.getOperand(1).isReg() &&
497-
CompareRegs.contains(MI.getOperand(1).getReg())) {
498-
assert(MI.getOperand(0).isReg() && MI.getOperand(1).isReg());
499-
Register Dst = MI.getOperand(0).getReg();
500-
CompareRegs.insert(Dst);
501-
PostUpdateArtifacts.push_back(&MI);
502-
}
503-
504-
// G_ICMPs relating to switches.
505-
if (MI.getOpcode() == TargetOpcode::G_ICMP && MI.getOperand(2).isReg() &&
506-
CompareRegs.contains(MI.getOperand(2).getReg())) {
507-
Register Dst = MI.getOperand(0).getReg();
508-
RelevantInsts.push_back(&MI);
509-
PostUpdateArtifacts.push_back(&MI);
510-
MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
511-
assert(CBr->getOpcode() == SPIRV::G_BRCOND);
512-
PostUpdateArtifacts.push_back(CBr);
513-
MachineInstr *Br = CBr->getNextNode();
514-
assert(Br->getOpcode() == SPIRV::G_BR);
515-
PostUpdateArtifacts.push_back(Br);
455+
SmallVector<MachineInstr *, 8> NewOps;
456+
for (unsigned i = 2; i < MI.getNumOperands(); ++i) {
457+
Register Reg = MI.getOperand(i).getReg();
458+
if (i % 2 == 1) {
459+
MachineInstr *ConstInstr = getDefInstrMaybeConstant(Reg, &MRI);
460+
NewOps.push_back(ConstInstr);
461+
} else {
462+
MachineInstr *BuildMBB = MRI.getVRegDef(Reg);
463+
assert(BuildMBB &&
464+
BuildMBB->getOpcode() == TargetOpcode::G_BLOCK_ADDR &&
465+
BuildMBB->getOperand(1).isBlockAddress() &&
466+
BuildMBB->getOperand(1).getBlockAddress());
467+
NewOps.push_back(BuildMBB);
468+
}
516469
}
470+
Switches.push_back(std::make_pair(&MI, NewOps));
517471
}
518472
}
519473

520-
// Update each spv_switch with destination MBBs.
521-
for (auto i = RelevantInsts.begin(); i != RelevantInsts.end(); i++) {
522-
if (!isSpvIntrinsic(**i, Intrinsic::spv_switch))
523-
continue;
524-
525-
// Currently considered spv_switch.
526-
MachineInstr *Switch = *i;
527-
// Set the first successor as default MBB to support empty switches.
528-
MachineBasicBlock *DefaultMBB = *Switch->getParent()->succ_begin();
529-
// Container for mapping values to MMBs.
530-
SmallDenseMap<uint64_t, MachineBasicBlock *> ValuesToMBBs;
531-
532-
// Walk all G_ICMPs to collect ValuesToMBBs. Start at currently considered
533-
// spv_switch (i) and break at any spv_switch with the same compare
534-
// register (indicating we are back at the same scope).
535-
Register CompareReg = Switch->getOperand(1).getReg();
536-
for (auto j = i + 1; j != RelevantInsts.end(); j++) {
537-
if (isSpvIntrinsic(**j, Intrinsic::spv_switch) &&
538-
(*j)->getOperand(1).getReg() == CompareReg)
539-
break;
540-
541-
if (!((*j)->getOpcode() == TargetOpcode::G_ICMP &&
542-
(*j)->getOperand(2).getReg() == CompareReg))
543-
continue;
544-
545-
MachineInstr *ICMP = *j;
546-
Register Dst = ICMP->getOperand(0).getReg();
547-
MachineOperand &PredOp = ICMP->getOperand(1);
548-
const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
549-
(void)CC;
550-
assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
551-
MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
552-
uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
553-
MachineInstr *CBr = MRI.use_begin(Dst)->getParent();
554-
assert(CBr->getOpcode() == SPIRV::G_BRCOND && CBr->getOperand(1).isMBB());
555-
MachineBasicBlock *MBB = CBr->getOperand(1).getMBB();
556-
557-
// Map switch case Value to target MBB.
558-
ValuesToMBBs[Value] = MBB;
559-
560-
// Add target MBB as successor to the switch's MBB.
561-
Switch->getParent()->addSuccessor(MBB);
562-
563-
// The next MI is always G_BR to either the next case or the default.
564-
MachineInstr *NextMI = CBr->getNextNode();
565-
assert(NextMI->getOpcode() == SPIRV::G_BR &&
566-
NextMI->getOperand(0).isMBB());
567-
MachineBasicBlock *NextMBB = NextMI->getOperand(0).getMBB();
568-
// Default MBB does not begin with G_ICMP using spv_switch compare
569-
// register.
570-
if (NextMBB->front().getOpcode() != SPIRV::G_ICMP ||
571-
(NextMBB->front().getOperand(2).isReg() &&
572-
NextMBB->front().getOperand(2).getReg() != CompareReg)) {
573-
// Set default MBB and add it as successor to the switch's MBB.
574-
DefaultMBB = NextMBB;
575-
Switch->getParent()->addSuccessor(DefaultMBB);
474+
SmallPtrSet<MachineInstr *, 8> ToEraseMI;
475+
for (auto &SwIt : Switches) {
476+
MachineInstr &MI = *SwIt.first;
477+
SmallVector<MachineInstr *, 8> &Ins = SwIt.second;
478+
SmallVector<MachineOperand, 8> NewOps;
479+
for (unsigned i = 0; i < Ins.size(); ++i) {
480+
if (Ins[i]->getOpcode() == TargetOpcode::G_BLOCK_ADDR) {
481+
BasicBlock *CaseBB =
482+
Ins[i]->getOperand(1).getBlockAddress()->getBasicBlock();
483+
auto It = BB2MBB.find(CaseBB);
484+
if (It == BB2MBB.end())
485+
report_fatal_error("cannot find a machine basic block by a basic "
486+
"block in a switch statement");
487+
NewOps.push_back(MachineOperand::CreateMBB(It->second));
488+
MI.getParent()->addSuccessor(It->second);
489+
ToEraseMI.insert(Ins[i]);
490+
} else {
491+
NewOps.push_back(
492+
MachineOperand::CreateCImm(Ins[i]->getOperand(1).getCImm()));
576493
}
577494
}
578-
579-
// Modify considered spv_switch operands using collected Values and
580-
// MBBs.
581-
SmallVector<const ConstantInt *, 3> Values;
582-
SmallVector<MachineBasicBlock *, 3> MBBs;
583-
for (unsigned k = 2; k < Switch->getNumExplicitOperands(); k++) {
584-
Register CReg = Switch->getOperand(k).getReg();
585-
uint64_t Val = getIConstVal(CReg, &MRI);
586-
MachineInstr *ConstInstr = getDefInstrMaybeConstant(CReg, &MRI);
587-
if (!ValuesToMBBs[Val])
588-
continue;
589-
590-
Values.push_back(ConstInstr->getOperand(1).getCImm());
591-
MBBs.push_back(ValuesToMBBs[Val]);
592-
}
593-
594-
for (unsigned k = Switch->getNumExplicitOperands() - 1; k > 1; k--)
595-
Switch->removeOperand(k);
596-
597-
Switch->addOperand(MachineOperand::CreateMBB(DefaultMBB));
598-
for (unsigned k = 0; k < Values.size(); k++) {
599-
Switch->addOperand(MachineOperand::CreateCImm(Values[k]));
600-
Switch->addOperand(MachineOperand::CreateMBB(MBBs[k]));
601-
}
602-
}
603-
604-
for (MachineInstr *MI : PostUpdateArtifacts) {
605-
MachineBasicBlock *ParentMBB = MI->getParent();
606-
MI->eraseFromParent();
607-
// If G_ICMP + G_BRCOND + G_BR were the only MIs in MBB, erase this MBB. It
608-
// can be safely assumed, there are no breaks or phis directing into this
609-
// MBB. However, we need to remove this MBB from the CFG graph. MBBs must be
610-
// erased top-down.
611-
if (ParentMBB->empty()) {
612-
while (!ParentMBB->pred_empty())
613-
(*ParentMBB->pred_begin())->removeSuccessor(ParentMBB);
614-
615-
while (!ParentMBB->succ_empty())
616-
ParentMBB->removeSuccessor(ParentMBB->succ_begin());
617-
618-
ParentMBB->eraseFromParent();
495+
for (unsigned i = MI.getNumOperands() - 1; i > 1; --i)
496+
MI.removeOperand(i);
497+
for (auto &MO : NewOps)
498+
MI.addOperand(MO);
499+
if (MachineInstr *Next = MI.getNextNode()) {
500+
if (isSpvIntrinsic(*Next, Intrinsic::spv_track_constant)) {
501+
ToEraseMI.insert(Next);
502+
Next = MI.getNextNode();
503+
}
504+
if (Next && Next->getOpcode() == TargetOpcode::G_BRINDIRECT)
505+
ToEraseMI.insert(Next);
619506
}
620507
}
508+
for (MachineInstr *BlockAddrI : ToEraseMI)
509+
BlockAddrI->eraseFromParent();
621510
}
622511

623512
static bool isImplicitFallthrough(MachineBasicBlock &MBB) {

0 commit comments

Comments
 (0)