Skip to content

Commit 1687555

Browse files
author
Thorsten Schütt
authored
[GlobalIsel] Combine select of binops (#76763)
1 parent 5b33cff commit 1687555

File tree

4 files changed

+322
-28
lines changed

4 files changed

+322
-28
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,9 @@ class CombinerHelper {
910910

911911
bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);
912912

913+
/// Try to fold select(cc, binop(), binop()) -> binop(select(), X)
914+
bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);
915+
913916
bool isOneOrOneSplat(Register Src, bool AllowUndefs);
914917
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
915918
bool isConstantSplatVector(Register Src, int64_t SplatValue,

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,109 @@ class GVecReduce : public GenericMachineInstr {
558558
}
559559
};
560560

561+
// Represents a binary operation, i.e, x = y op z.
562+
class GBinOp : public GenericMachineInstr {
563+
public:
564+
Register getLHSReg() const { return getReg(1); }
565+
Register getRHSReg() const { return getReg(2); }
566+
567+
static bool classof(const MachineInstr *MI) {
568+
switch (MI->getOpcode()) {
569+
// Integer.
570+
case TargetOpcode::G_ADD:
571+
case TargetOpcode::G_SUB:
572+
case TargetOpcode::G_MUL:
573+
case TargetOpcode::G_SDIV:
574+
case TargetOpcode::G_UDIV:
575+
case TargetOpcode::G_SREM:
576+
case TargetOpcode::G_UREM:
577+
case TargetOpcode::G_SMIN:
578+
case TargetOpcode::G_SMAX:
579+
case TargetOpcode::G_UMIN:
580+
case TargetOpcode::G_UMAX:
581+
// Floating point.
582+
case TargetOpcode::G_FMINNUM:
583+
case TargetOpcode::G_FMAXNUM:
584+
case TargetOpcode::G_FMINNUM_IEEE:
585+
case TargetOpcode::G_FMAXNUM_IEEE:
586+
case TargetOpcode::G_FMINIMUM:
587+
case TargetOpcode::G_FMAXIMUM:
588+
case TargetOpcode::G_FADD:
589+
case TargetOpcode::G_FSUB:
590+
case TargetOpcode::G_FMUL:
591+
case TargetOpcode::G_FDIV:
592+
case TargetOpcode::G_FPOW:
593+
// Logical.
594+
case TargetOpcode::G_AND:
595+
case TargetOpcode::G_OR:
596+
case TargetOpcode::G_XOR:
597+
return true;
598+
default:
599+
return false;
600+
}
601+
};
602+
};
603+
604+
// Represents an integer binary operation.
605+
class GIntBinOp : public GBinOp {
606+
public:
607+
static bool classof(const MachineInstr *MI) {
608+
switch (MI->getOpcode()) {
609+
case TargetOpcode::G_ADD:
610+
case TargetOpcode::G_SUB:
611+
case TargetOpcode::G_MUL:
612+
case TargetOpcode::G_SDIV:
613+
case TargetOpcode::G_UDIV:
614+
case TargetOpcode::G_SREM:
615+
case TargetOpcode::G_UREM:
616+
case TargetOpcode::G_SMIN:
617+
case TargetOpcode::G_SMAX:
618+
case TargetOpcode::G_UMIN:
619+
case TargetOpcode::G_UMAX:
620+
return true;
621+
default:
622+
return false;
623+
}
624+
};
625+
};
626+
627+
// Represents a floating point binary operation.
628+
class GFBinOp : public GBinOp {
629+
public:
630+
static bool classof(const MachineInstr *MI) {
631+
switch (MI->getOpcode()) {
632+
case TargetOpcode::G_FMINNUM:
633+
case TargetOpcode::G_FMAXNUM:
634+
case TargetOpcode::G_FMINNUM_IEEE:
635+
case TargetOpcode::G_FMAXNUM_IEEE:
636+
case TargetOpcode::G_FMINIMUM:
637+
case TargetOpcode::G_FMAXIMUM:
638+
case TargetOpcode::G_FADD:
639+
case TargetOpcode::G_FSUB:
640+
case TargetOpcode::G_FMUL:
641+
case TargetOpcode::G_FDIV:
642+
case TargetOpcode::G_FPOW:
643+
return true;
644+
default:
645+
return false;
646+
}
647+
};
648+
};
649+
650+
// Represents a logical binary operation.
651+
class GLogicalBinOp : public GBinOp {
652+
public:
653+
static bool classof(const MachineInstr *MI) {
654+
switch (MI->getOpcode()) {
655+
case TargetOpcode::G_AND:
656+
case TargetOpcode::G_OR:
657+
case TargetOpcode::G_XOR:
658+
return true;
659+
default:
660+
return false;
661+
}
662+
};
663+
};
561664

562665
} // namespace llvm
563666

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6390,8 +6390,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
63906390
if (TrueValue.isZero() && FalseValue.isOne()) {
63916391
MatchInfo = [=](MachineIRBuilder &B) {
63926392
B.setInstrAndDebugLoc(*Select);
6393-
Register Inner = MRI.createGenericVirtualRegister(CondTy);
6394-
B.buildNot(Inner, Cond);
6393+
auto Inner = B.buildNot(CondTy, Cond);
63956394
B.buildZExtOrTrunc(Dest, Inner);
63966395
};
63976396
return true;
@@ -6401,8 +6400,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
64016400
if (TrueValue.isZero() && FalseValue.isAllOnes()) {
64026401
MatchInfo = [=](MachineIRBuilder &B) {
64036402
B.setInstrAndDebugLoc(*Select);
6404-
Register Inner = MRI.createGenericVirtualRegister(CondTy);
6405-
B.buildNot(Inner, Cond);
6403+
auto Inner = B.buildNot(CondTy, Cond);
64066404
B.buildSExtOrTrunc(Dest, Inner);
64076405
};
64086406
return true;
@@ -6412,8 +6410,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
64126410
if (TrueValue - 1 == FalseValue) {
64136411
MatchInfo = [=](MachineIRBuilder &B) {
64146412
B.setInstrAndDebugLoc(*Select);
6415-
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6416-
B.buildZExtOrTrunc(Inner, Cond);
6413+
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
64176414
B.buildAdd(Dest, Inner, False);
64186415
};
64196416
return true;
@@ -6423,8 +6420,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
64236420
if (TrueValue + 1 == FalseValue) {
64246421
MatchInfo = [=](MachineIRBuilder &B) {
64256422
B.setInstrAndDebugLoc(*Select);
6426-
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6427-
B.buildSExtOrTrunc(Inner, Cond);
6423+
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
64286424
B.buildAdd(Dest, Inner, False);
64296425
};
64306426
return true;
@@ -6434,8 +6430,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
64346430
if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
64356431
MatchInfo = [=](MachineIRBuilder &B) {
64366432
B.setInstrAndDebugLoc(*Select);
6437-
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6438-
B.buildZExtOrTrunc(Inner, Cond);
6433+
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
64396434
// The shift amount must be scalar.
64406435
LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
64416436
auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
@@ -6447,8 +6442,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
64476442
if (TrueValue.isAllOnes()) {
64486443
MatchInfo = [=](MachineIRBuilder &B) {
64496444
B.setInstrAndDebugLoc(*Select);
6450-
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6451-
B.buildSExtOrTrunc(Inner, Cond);
6445+
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
64526446
B.buildOr(Dest, Inner, False, Flags);
64536447
};
64546448
return true;
@@ -6458,10 +6452,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
64586452
if (FalseValue.isAllOnes()) {
64596453
MatchInfo = [=](MachineIRBuilder &B) {
64606454
B.setInstrAndDebugLoc(*Select);
6461-
Register Not = MRI.createGenericVirtualRegister(CondTy);
6462-
B.buildNot(Not, Cond);
6463-
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
6464-
B.buildSExtOrTrunc(Inner, Not);
6455+
auto Not = B.buildNot(CondTy, Cond);
6456+
auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
64656457
B.buildOr(Dest, Inner, True, Flags);
64666458
};
64676459
return true;
@@ -6496,8 +6488,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
64966488
if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
64976489
MatchInfo = [=](MachineIRBuilder &B) {
64986490
B.setInstrAndDebugLoc(*Select);
6499-
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6500-
B.buildZExtOrTrunc(Ext, Cond);
6491+
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
65016492
B.buildOr(DstReg, Ext, False, Flags);
65026493
};
65036494
return true;
@@ -6508,8 +6499,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
65086499
if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
65096500
MatchInfo = [=](MachineIRBuilder &B) {
65106501
B.setInstrAndDebugLoc(*Select);
6511-
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6512-
B.buildZExtOrTrunc(Ext, Cond);
6502+
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
65136503
B.buildAnd(DstReg, Ext, True);
65146504
};
65156505
return true;
@@ -6520,11 +6510,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
65206510
MatchInfo = [=](MachineIRBuilder &B) {
65216511
B.setInstrAndDebugLoc(*Select);
65226512
// First the not.
6523-
Register Inner = MRI.createGenericVirtualRegister(CondTy);
6524-
B.buildNot(Inner, Cond);
6513+
auto Inner = B.buildNot(CondTy, Cond);
65256514
// Then an ext to match the destination register.
6526-
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6527-
B.buildZExtOrTrunc(Ext, Inner);
6515+
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
65286516
B.buildOr(DstReg, Ext, True, Flags);
65296517
};
65306518
return true;
@@ -6535,11 +6523,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
65356523
MatchInfo = [=](MachineIRBuilder &B) {
65366524
B.setInstrAndDebugLoc(*Select);
65376525
// First the not.
6538-
Register Inner = MRI.createGenericVirtualRegister(CondTy);
6539-
B.buildNot(Inner, Cond);
6526+
auto Inner = B.buildNot(CondTy, Cond);
65406527
// Then an ext to match the destination register.
6541-
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
6542-
B.buildZExtOrTrunc(Ext, Inner);
6528+
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
65436529
B.buildAnd(DstReg, Ext, False);
65446530
};
65456531
return true;
@@ -6548,6 +6534,54 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
65486534
return false;
65496535
}
65506536

6537+
/// Try to fold a select whose true and false values are produced by the same
/// kind of binary operation and share one operand:
///
///   select(cond, binop(x, y), binop(z, y)) --> binop(select(cond, x, z), y)
///   select(cond, binop(x, y), binop(x, z)) --> binop(x, select(cond, y, z))
///
/// \param Select    the select instruction to inspect.
/// \param MatchInfo on success, receives the callback that builds the
///                  replacement instructions.
/// \returns true if one of the folds applies and \p MatchInfo was set.
bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
                                           BuildFnTy &MatchInfo) {
  Register DstReg = Select->getReg(0);
  Register Cond = Select->getCondReg();
  Register False = Select->getFalseReg();
  Register True = Select->getTrueReg();
  LLT DstTy = MRI.getType(DstReg);

  GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI);
  GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI);

  // We need two binops of the same kind on the true/false registers.
  if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
    return false;

  // Require each binop result to have a single non-debug user. If either
  // result is used elsewhere, that binop stays alive after the fold and we
  // would be adding a select and a binop on top of the existing instructions
  // instead of replacing them.
  if (!MRI.hasOneNonDBGUse(LHS->getReg(0)) ||
      !MRI.hasOneNonDBGUse(RHS->getReg(0)))
    return false;

  // Note that there are no constraints on CondTy.
  unsigned Flags = (LHS->getFlags() & RHS->getFlags()) | Select->getFlags();
  unsigned Opcode = LHS->getOpcode();

  // NOTE(review): the select's flags are also forwarded to the new inner
  // select below. For opcodes whose result can be well-behaved even though an
  // input is not (e.g. the FP min/max family with a NaN input), that may be a
  // stronger assumption than the original select made - confirm.
  unsigned SelectFlags = Select->getFlags();

  // Read the operand registers now so that the deferred build callback only
  // captures plain values and does not dereference the matched binops later.
  Register TrueLHS = LHS->getLHSReg();
  Register TrueRHS = LHS->getRHSReg();
  Register FalseLHS = RHS->getLHSReg();
  Register FalseRHS = RHS->getRHSReg();

  // Fold select(cond, binop(x, y), binop(z, y))
  //  --> binop(select(cond, x, z), y)
  if (TrueRHS == FalseRHS) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      auto Sel = B.buildSelect(DstTy, Cond, TrueLHS, FalseLHS, SelectFlags);
      B.buildInstr(Opcode, {DstReg}, {Sel, TrueRHS}, Flags);
    };
    return true;
  }

  // Fold select(cond, binop(x, y), binop(x, z))
  //  --> binop(x, select(cond, y, z))
  if (TrueLHS == FalseLHS) {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.setInstrAndDebugLoc(*Select);
      auto Sel = B.buildSelect(DstTy, Cond, TrueRHS, FalseRHS, SelectFlags);
      B.buildInstr(Opcode, {DstReg}, {TrueLHS, Sel}, Flags);
    };
    return true;
  }

  // FIXME: use isCommutable() to also match binops whose shared operand sits
  // on opposite sides.

  return false;
}
6584+
65516585
bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
65526586
GSelect *Select = cast<GSelect>(&MI);
65536587

@@ -6557,5 +6591,8 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
65576591
if (tryFoldBoolSelectToLogic(Select, MatchInfo))
65586592
return true;
65596593

6594+
if (tryFoldSelectOfBinOps(Select, MatchInfo))
6595+
return true;
6596+
65606597
return false;
65616598
}

0 commit comments

Comments
 (0)