Skip to content

[SelectionDAG][X86] Preserve unpredictable metadata for conditional branches in SelectionDAG, as well as JCCs generated by X86 backend. #102101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1165,8 +1165,13 @@ class SelectionDAG {
SDValue N2, SDValue N3, const SDNodeFlags Flags);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDValue N4);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDValue N4, const SDNodeFlags Flags);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDValue N4, SDValue N5);
SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1,
SDValue N2, SDValue N3, SDValue N4, SDValue N5,
const SDNodeFlags Flags);

// Specialize again based on number of operands for nodes with a VTList
// rather than a single VT.
Expand Down
15 changes: 10 additions & 5 deletions llvm/include/llvm/CodeGen/SwitchLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,29 +137,34 @@ struct CaseBlock {
SDLoc DL;
DebugLoc DbgLoc;

// Branch weights.
// Branch weights and predictability.
BranchProbability TrueProb, FalseProb;
bool IsUnpredictable;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where this get used?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind. I found 09515f2 in your description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See SelectionDAGBuilder::visitSwitchCase.


// Constructor for SelectionDAG.
CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
const Value *cmpmiddle, MachineBasicBlock *truebb,
MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
BranchProbability trueprob = BranchProbability::getUnknown(),
BranchProbability falseprob = BranchProbability::getUnknown())
BranchProbability falseprob = BranchProbability::getUnknown(),
bool isunpredictable = false)
: CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
TrueProb(trueprob), FalseProb(falseprob) {}
TrueProb(trueprob), FalseProb(falseprob),
IsUnpredictable(isunpredictable) {}

// Constructor for GISel.
CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
const Value *cmprhs, const Value *cmpmiddle,
MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
MachineBasicBlock *me, DebugLoc dl,
BranchProbability trueprob = BranchProbability::getUnknown(),
BranchProbability falseprob = BranchProbability::getUnknown())
BranchProbability falseprob = BranchProbability::getUnknown(),
bool isunpredictable = false)
: PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob),
IsUnpredictable(isunpredictable) {}
};

struct JumpTable {
Expand Down
7 changes: 4 additions & 3 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18164,7 +18164,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
// nondeterministic jumps).
if (N1->getOpcode() == ISD::FREEZE && N1.hasOneUse()) {
return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
N1->getOperand(0), N2);
N1->getOperand(0), N2, N->getFlags());
}

// Variant of the previous fold where there is a SETCC in between:
Expand Down Expand Up @@ -18213,7 +18213,8 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
if (Updated)
return DAG.getNode(
ISD::BRCOND, SDLoc(N), MVT::Other, Chain,
DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2);
DAG.getSetCC(SDLoc(N1), N1->getValueType(0), S0, S1, Cond), N2,
N->getFlags());
}

// If N is a constant we could fold this into a fallthrough or unconditional
Expand All @@ -18238,7 +18239,7 @@ SDValue DAGCombiner::visitBRCOND(SDNode *N) {
HandleSDNode ChainHandle(Chain);
if (SDValue NewN1 = rebuildSetCC(N1))
return DAG.getNode(ISD::BRCOND, SDLoc(N), MVT::Other,
ChainHandle.getValue(), NewN1, N2);
ChainHandle.getValue(), NewN1, N2, N->getFlags());
}

return SDValue();
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1065,14 +1065,17 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,
// Create the new machine instruction.
MachineInstrBuilder MIB = BuildMI(*MF, Node->getDebugLoc(), II);

// Transfer IR flags from the SDNode to the MachineInstr
MachineInstr *MI = MIB.getInstr();
const SDNodeFlags Flags = Node->getFlags();
if (Flags.hasUnpredictable())
MI->setFlag(MachineInstr::MIFlag::Unpredictable);

// Add result register values for things that are defined by this
// instruction.
if (NumResults) {
CreateVirtualRegisters(Node, MIB, II, IsClone, IsCloned, VRBaseMap);

// Transfer any IR flags from the SDNode to the MachineInstr
MachineInstr *MI = MIB.getInstr();
const SDNodeFlags Flags = Node->getFlags();
if (Flags.hasNoSignedZeros())
MI->setFlag(MachineInstr::MIFlag::FmNsz);

Expand Down Expand Up @@ -1105,9 +1108,6 @@ EmitMachineNode(SDNode *Node, bool IsClone, bool IsCloned,

if (Flags.hasNoFPExcept())
MI->setFlag(MachineInstr::MIFlag::NoFPExcept);

if (Flags.hasUnpredictable())
MI->setFlag(MachineInstr::MIFlag::Unpredictable);
}

// Emit all of the actual operands of this instruction, adding them to the
Expand Down
26 changes: 22 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7613,16 +7613,34 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
SDValue N1, SDValue N2, SDValue N3, SDValue N4,
const SDNodeFlags Flags) {
SDValue Ops[] = { N1, N2, N3, N4 };
return getNode(Opcode, DL, VT, Ops);
return getNode(Opcode, DL, VT, Ops, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, SDValue N2, SDValue N3, SDValue N4) {
SDNodeFlags Flags;
if (Inserter)
Flags = Inserter->getFlags();
return getNode(Opcode, DL, VT, N1, N2, N3, N4, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, SDValue N2, SDValue N3, SDValue N4,
SDValue N5) {
SDValue N5, const SDNodeFlags Flags) {
SDValue Ops[] = { N1, N2, N3, N4, N5 };
return getNode(Opcode, DL, VT, Ops);
return getNode(Opcode, DL, VT, Ops, Flags);
}

SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
SDValue N1, SDValue N2, SDValue N3, SDValue N4,
SDValue N5) {
SDNodeFlags Flags;
if (Inserter)
Flags = Inserter->getFlags();
return getNode(Opcode, DL, VT, N1, N2, N3, N4, N5, Flags);
}

/// getStackArgumentTokenFactor - Compute a TokenFactor to force all
Expand Down
14 changes: 9 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2818,9 +2818,10 @@ void SelectionDAGBuilder::visitBr(const BranchInst &I) {
// je foo
// cmp D, E
// jle foo
bool IsUnpredictable = I.hasMetadata(LLVMContext::MD_unpredictable);
const Instruction *BOp = dyn_cast<Instruction>(CondVal);
if (!DAG.getTargetLoweringInfo().isJumpExpensive() && BOp &&
BOp->hasOneUse() && !I.hasMetadata(LLVMContext::MD_unpredictable)) {
BOp->hasOneUse() && !IsUnpredictable) {
Value *Vec;
const Value *BOp0, *BOp1;
Instruction::BinaryOps Opcode = (Instruction::BinaryOps)0;
Expand Down Expand Up @@ -2869,7 +2870,9 @@ void SelectionDAGBuilder::visitBr(const BranchInst &I) {

// Create a CaseBlock record representing this branch.
CaseBlock CB(ISD::SETEQ, CondVal, ConstantInt::getTrue(*DAG.getContext()),
nullptr, Succ0MBB, Succ1MBB, BrMBB, getCurSDLoc());
nullptr, Succ0MBB, Succ1MBB, BrMBB, getCurSDLoc(),
BranchProbability::getUnknown(), BranchProbability::getUnknown(),
IsUnpredictable);

// Use visitSwitchCase to actually insert the fast branch sequence for this
// cond branch.
Expand Down Expand Up @@ -2957,9 +2960,10 @@ void SelectionDAGBuilder::visitSwitchCase(CaseBlock &CB,
Cond = DAG.getNode(ISD::XOR, dl, Cond.getValueType(), Cond, True);
}

SDValue BrCond = DAG.getNode(ISD::BRCOND, dl,
MVT::Other, getControlRoot(), Cond,
DAG.getBasicBlock(CB.TrueBB));
SDNodeFlags Flags;
Flags.setUnpredictable(CB.IsUnpredictable);
SDValue BrCond = DAG.getNode(ISD::BRCOND, dl, MVT::Other, getControlRoot(),
Cond, DAG.getBasicBlock(CB.TrueBB), Flags);

setValue(CurInst, BrCond);

Expand Down
22 changes: 11 additions & 11 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24827,14 +24827,14 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {

SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
Overflow);
Overflow, Op->getFlags());
}

if (LHS.getSimpleValueType().isInteger()) {
SDValue CCVal;
SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, CC, SDLoc(Cond), DAG, CCVal);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
EFLAGS);
EFLAGS, Op->getFlags());
}

if (CC == ISD::SETOEQ) {
Expand All @@ -24860,10 +24860,10 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest,
CCVal, Cmp);
CCVal, Cmp, Op->getFlags());
CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
Cmp);
Cmp, Op->getFlags());
}
}
} else if (CC == ISD::SETUNE) {
Expand All @@ -24872,18 +24872,18 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {
// separate test.
SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
SDValue CCVal = DAG.getTargetConstant(X86::COND_NE, dl, MVT::i8);
Chain =
DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, Cmp);
Chain = DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
Cmp, Op->getFlags());
CCVal = DAG.getTargetConstant(X86::COND_P, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
Cmp);
Cmp, Op->getFlags());
} else {
X86::CondCode X86Cond =
TranslateX86CC(CC, dl, /*IsFP*/ true, LHS, RHS, DAG);
SDValue Cmp = DAG.getNode(X86ISD::FCMP, SDLoc(Cond), MVT::i32, LHS, RHS);
SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
Cmp);
Cmp, Op->getFlags());
}
}

Expand All @@ -24894,7 +24894,7 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {

SDValue CCVal = DAG.getTargetConstant(X86Cond, dl, MVT::i8);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
Overflow);
Overflow, Op->getFlags());
}

// Look past the truncate if the high bits are known zero.
Expand All @@ -24913,8 +24913,8 @@ SDValue X86TargetLowering::LowerBRCOND(SDValue Op, SelectionDAG &DAG) const {

SDValue CCVal;
SDValue EFLAGS = emitFlagsForSetcc(LHS, RHS, ISD::SETNE, dl, DAG, CCVal);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
EFLAGS);
return DAG.getNode(X86ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal, EFLAGS,
Op->getFlags());
}

// Lower dynamic stack allocation to _alloca call for Cygwin/Mingw targets.
Expand Down
10 changes: 5 additions & 5 deletions llvm/test/CodeGen/X86/unpredictable-brcond.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-commit the test case?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. Submitted #102262.

; Currently, unpredictable metadata on conditional branches is lost during CodeGen.
; Make sure MIR generated for conditional branch with unpredictable metadata has unpredictable flag.
; RUN: llc -mtriple=x86_64-unknown-linux-gnu -stop-after=finalize-isel < %s | FileCheck %s

define void @cond_branch_1(i1 %cond) {
Expand All @@ -11,7 +11,7 @@ define void @cond_branch_1(i1 %cond) {
; CHECK-NEXT: [[COPY:%[0-9]+]]:gr32 = COPY $edi
; CHECK-NEXT: [[COPY1:%[0-9]+]]:gr8 = COPY [[COPY]].sub_8bit
; CHECK-NEXT: TEST8ri killed [[COPY1]], 1, implicit-def $eflags
; CHECK-NEXT: JCC_1 %bb.2, 4, implicit $eflags
; CHECK-NEXT: unpredictable JCC_1 %bb.2, 4, implicit $eflags
; CHECK-NEXT: JMP_1 %bb.1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: bb.1.then:
Expand Down Expand Up @@ -51,7 +51,7 @@ define void @cond_branch_2(double %a, double %b, i32 %c, i32 %d) nounwind {
; CHECK-NEXT: [[SETCCr1:%[0-9]+]]:gr8 = SETCCr 6, implicit $eflags
; CHECK-NEXT: [[OR8rr:%[0-9]+]]:gr8 = OR8rr [[SETCCr]], killed [[SETCCr1]], implicit-def dead $eflags
; CHECK-NEXT: TEST8rr [[OR8rr]], [[OR8rr]], implicit-def $eflags
; CHECK-NEXT: JCC_1 %bb.2, 5, implicit $eflags
; CHECK-NEXT: unpredictable JCC_1 %bb.2, 5, implicit $eflags
; CHECK-NEXT: JMP_1 %bb.1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: bb.1.true:
Expand Down Expand Up @@ -89,8 +89,8 @@ define void @isint_branch(double %d) nounwind {
; CHECK-NEXT: [[CVTDQ2PDrr:%[0-9]+]]:vr128 = CVTDQ2PDrr killed [[CVTTPD2DQrr]]
; CHECK-NEXT: [[COPY2:%[0-9]+]]:fr64 = COPY [[CVTDQ2PDrr]]
; CHECK-NEXT: nofpexcept UCOMISDrr [[COPY]], killed [[COPY2]], implicit-def $eflags, implicit $mxcsr
; CHECK-NEXT: JCC_1 %bb.2, 5, implicit $eflags
; CHECK-NEXT: JCC_1 %bb.2, 10, implicit $eflags
; CHECK-NEXT: unpredictable JCC_1 %bb.2, 5, implicit $eflags
; CHECK-NEXT: unpredictable JCC_1 %bb.2, 10, implicit $eflags
; CHECK-NEXT: JMP_1 %bb.1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: bb.1.true:
Expand Down
Loading