Skip to content

Commit fe6893b

Browse files
authored
Improve selection of conditional branch on amdgcn.ballot!=0 condition in SelectionDAG. (#68714)
Improve selection of the following pattern: bool cnd = ... if (amdgcn.ballot(cnd) != 0) { ... } which means "execute _then_ if any lane has satisfied the _cnd_ condition".
1 parent feedb7c commit fe6893b

File tree

7 files changed

+1719
-5
lines changed

7 files changed

+1719
-5
lines changed

llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "MCTargetDesc/AMDGPUMCTargetDesc.h"
2020
#include "MCTargetDesc/R600MCTargetDesc.h"
2121
#include "R600RegisterInfo.h"
22+
#include "SIISelLowering.h"
2223
#include "SIMachineFunctionInfo.h"
2324
#include "llvm/Analysis/UniformityAnalysis.h"
2425
#include "llvm/Analysis/ValueTracking.h"
@@ -2263,6 +2264,34 @@ bool AMDGPUDAGToDAGISel::isCBranchSCC(const SDNode *N) const {
22632264
return false;
22642265
}
22652266

2267+
// Tries to look through a ballot-style wave-mask compare and recover the
// underlying i1 condition.
//
// VCMP must be an AMDGPUISD::SETCC node (asserted). The match succeeds only
// when its condition code is SETEQ or SETNE, its second operand is a constant
// zero, and its first operand (after peeling at most one extension node)
// satisfies isBoolSGPR().
//
// On success the i1 condition is returned and Negate is set to true for SETEQ
// and to false for SETNE. On failure an empty SDValue is returned and Negate
// is left untouched.
static SDValue combineBallotPattern(SDValue VCMP, bool &Negate) {
2268+
assert(VCMP->getOpcode() == AMDGPUISD::SETCC);
2269+
// Special case for amdgcn.ballot:
2270+
// %Cond = i1 (and/or combination of i1 ISD::SETCCs)
2271+
// %VCMP = i(WaveSize) AMDGPUISD::SETCC (ext %Cond), 0, setne/seteq
2272+
// =>
2273+
// Use i1 %Cond value instead of i(WaveSize) %VCMP.
2274+
// This is possible because divergent ISD::SETCC is selected as V_CMP and
2275+
// Cond becomes an i(WaveSize) full mask value.
2276+
// Note that ballot doesn't use the SETEQ condition, but it's easy to support
2277+
// it here for completeness, so in this case Negate is set true on return.
2278+
auto VCMP_CC = cast<CondCodeSDNode>(VCMP.getOperand(2))->get();
2279+
// The right-hand side may be a non-constant node; dyn_cast then yields null
// and the pattern is rejected by the check below.
auto *VCMP_CRHS = dyn_cast<ConstantSDNode>(VCMP.getOperand(1));
2280+
if ((VCMP_CC == ISD::SETEQ || VCMP_CC == ISD::SETNE) && VCMP_CRHS &&
2281+
VCMP_CRHS->isZero()) {
2282+
2283+
auto Cond = VCMP.getOperand(0);
2284+
// Only a single extension node is peeled; nested extensions are not
// looked through.
if (ISD::isExtOpcode(Cond->getOpcode())) // Skip extension.
2285+
Cond = Cond.getOperand(0);
2286+
2287+
if (isBoolSGPR(Cond)) {
2288+
// Comparing "== 0" inverts the sense of the condition.
Negate = VCMP_CC == ISD::SETEQ;
2289+
return Cond;
2290+
}
2291+
}
2292+
// No match: signal failure with an empty SDValue.
return SDValue();
2293+
}
2294+
22662295
void AMDGPUDAGToDAGISel::SelectBRCOND(SDNode *N) {
22672296
SDValue Cond = N->getOperand(1);
22682297

@@ -2276,11 +2305,50 @@ void AMDGPUDAGToDAGISel::SelectBRCOND(SDNode *N) {
22762305
const SIRegisterInfo *TRI = ST->getRegisterInfo();
22772306

22782307
bool UseSCCBr = isCBranchSCC(N) && isUniformBr(N);
2279-
unsigned BrOp = UseSCCBr ? AMDGPU::S_CBRANCH_SCC1 : AMDGPU::S_CBRANCH_VCCNZ;
2308+
bool AndExec = !UseSCCBr;
2309+
bool Negate = false;
2310+
2311+
if (Cond.getOpcode() == ISD::SETCC &&
2312+
Cond->getOperand(0)->getOpcode() == AMDGPUISD::SETCC) {
2313+
SDValue VCMP = Cond->getOperand(0);
2314+
auto CC = cast<CondCodeSDNode>(Cond->getOperand(2))->get();
2315+
auto *CRHS = dyn_cast<ConstantSDNode>(Cond->getOperand(1));
2316+
if ((CC == ISD::SETEQ || CC == ISD::SETNE) && CRHS && CRHS->isZero() &&
2317+
// TODO: make condition below an assert after fixing ballot bitwidth.
2318+
VCMP.getValueType().getSizeInBits() == ST->getWavefrontSize()) {
2319+
// %VCMP = i(WaveSize) AMDGPUISD::SETCC ...
2320+
// %C = i1 ISD::SETCC %VCMP, 0, setne/seteq
2321+
// BRCOND i1 %C, %BB
2322+
// =>
2323+
// %VCMP = i(WaveSize) AMDGPUISD::SETCC ...
2324+
// VCC = COPY i(WaveSize) %VCMP
2325+
// S_CBRANCH_VCCNZ/VCCZ %BB
2326+
Negate = CC == ISD::SETEQ;
2327+
bool NegatedBallot = false;
2328+
if (auto BallotCond = combineBallotPattern(VCMP, NegatedBallot)) {
2329+
Cond = BallotCond;
2330+
UseSCCBr = !BallotCond->isDivergent();
2331+
Negate = Negate ^ NegatedBallot;
2332+
} else {
2333+
// TODO: don't use SCC here assuming that AMDGPUISD::SETCC is always
2334+
// selected as V_CMP, but this may change for uniform condition.
2335+
Cond = VCMP;
2336+
UseSCCBr = false;
2337+
}
2338+
}
2339+
// Cond is either V_CMP resulting from AMDGPUISD::SETCC or a combination of
2340+
// V_CMPs resulting from ballot, or ballot has uniform condition and SCC is
2341+
// used.
2342+
AndExec = false;
2343+
}
2344+
2345+
unsigned BrOp =
2346+
UseSCCBr ? (Negate ? AMDGPU::S_CBRANCH_SCC0 : AMDGPU::S_CBRANCH_SCC1)
2347+
: (Negate ? AMDGPU::S_CBRANCH_VCCZ : AMDGPU::S_CBRANCH_VCCNZ);
22802348
Register CondReg = UseSCCBr ? AMDGPU::SCC : TRI->getVCC();
22812349
SDLoc SL(N);
22822350

2283-
if (!UseSCCBr) {
2351+
if (AndExec) {
22842352
// This is the case that we are selecting to S_CBRANCH_VCCNZ. We have not
22852353
// analyzed what generates the vcc value, so we do not know whether vcc
22862354
// bits for disabled lanes are 0. Thus we need to mask out bits for

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10628,9 +10628,7 @@ SDValue SITargetLowering::splitBinaryBitConstantOp(
1062810628
return SDValue();
1062910629
}
1063010630

10631-
// Returns true if argument is a boolean value which is not serialized into
10632-
// memory or argument and does not require v_cndmask_b32 to be deserialized.
10633-
static bool isBoolSGPR(SDValue V) {
10631+
bool llvm::isBoolSGPR(SDValue V) {
1063410632
if (V.getValueType() != MVT::i1)
1063510633
return false;
1063610634
switch (V.getOpcode()) {

llvm/lib/Target/AMDGPU/SIISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,10 @@ class SITargetLowering final : public AMDGPUTargetLowering {
586586
getTargetMMOFlags(const Instruction &I) const override;
587587
};
588588

589+
// Returns true if argument is a boolean value which is not serialized into
590+
// memory or argument and does not require v_cndmask_b32 to be deserialized.
591+
bool isBoolSGPR(SDValue V);
592+
589593
} // End namespace llvm
590594

591595
#endif

0 commit comments

Comments
 (0)