Skip to content

[GlobalIsel] Combine select of binops #76763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,9 @@ class CombinerHelper {

bool tryFoldSelectOfConstants(GSelect *Select, BuildFnTy &MatchInfo);

/// Try to fold select(cc, binop(), binop()) -> binop(select(), X)
bool tryFoldSelectOfBinOps(GSelect *Select, BuildFnTy &MatchInfo);

bool isOneOrOneSplat(Register Src, bool AllowUndefs);
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
bool isConstantSplatVector(Register Src, int64_t SplatValue,
Expand Down
103 changes: 103 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,109 @@ class GVecReduce : public GenericMachineInstr {
}
};

// Represents a binary operation, i.e., x = y op z.
// Base wrapper for the integer, floating point and logical binary operation
// classes that derive from it in this header. It adds no state of its own;
// it only provides operand accessors and an opcode-based classof().
class GBinOp : public GenericMachineInstr {
public:
// Register obtained via getReg(1); treated as the left-hand operand (y).
Register getLHSReg() const { return getReg(1); }
// Register obtained via getReg(2); treated as the right-hand operand (z).
Register getRHSReg() const { return getReg(2); }

// Returns true iff MI's opcode is one of the integer, floating point or
// logical opcodes listed in the switch; every other opcode yields false.
// The list is spelled out here instead of delegating to the derived classes'
// classof(), because those classes are only defined further down the header.
static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you implement this in terms of GIntBinOp::classof and GFBinOp::classof to avoid listing the opcodes in multiple places?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to suggest that too but it won't compile due to not yet seeing the definition of GIntBinOp when parsing this class.

// Integer.
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_MUL:
case TargetOpcode::G_SDIV:
case TargetOpcode::G_UDIV:
case TargetOpcode::G_SREM:
case TargetOpcode::G_UREM:
case TargetOpcode::G_SMIN:
case TargetOpcode::G_SMAX:
case TargetOpcode::G_UMIN:
case TargetOpcode::G_UMAX:
// Floating point.
case TargetOpcode::G_FMINNUM:
case TargetOpcode::G_FMAXNUM:
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
case TargetOpcode::G_FMINIMUM:
case TargetOpcode::G_FMAXIMUM:
case TargetOpcode::G_FADD:
case TargetOpcode::G_FSUB:
case TargetOpcode::G_FMUL:
case TargetOpcode::G_FDIV:
case TargetOpcode::G_FPOW:
// Logical.
case TargetOpcode::G_AND:
case TargetOpcode::G_OR:
case TargetOpcode::G_XOR:
return true;
default:
return false;
}
};
};

/// Wrapper for a generic integer binary operation: add, sub, mul, signed and
/// unsigned division/remainder, and signed and unsigned min/max.
class GIntBinOp : public GBinOp {
public:
  /// Type-test hook: true exactly when \p MI has one of the integer binary
  /// opcodes enumerated below.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    // Arithmetic.
    case TargetOpcode::G_ADD:
    case TargetOpcode::G_SUB:
    case TargetOpcode::G_MUL:
    // Division and remainder.
    case TargetOpcode::G_SDIV:
    case TargetOpcode::G_UDIV:
    case TargetOpcode::G_SREM:
    case TargetOpcode::G_UREM:
    // Minimum and maximum.
    case TargetOpcode::G_SMIN:
    case TargetOpcode::G_SMAX:
    case TargetOpcode::G_UMIN:
    case TargetOpcode::G_UMAX:
      return true;
    default:
      break;
    }
    return false;
  }
};

/// Wrapper for a generic floating point binary operation: the min/max family,
/// the four basic arithmetic operations, and pow.
class GFBinOp : public GBinOp {
public:
  /// Type-test hook: true exactly when \p MI has one of the floating point
  /// binary opcodes enumerated below.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    // Minimum and maximum variants.
    case TargetOpcode::G_FMINNUM:
    case TargetOpcode::G_FMAXNUM:
    case TargetOpcode::G_FMINNUM_IEEE:
    case TargetOpcode::G_FMAXNUM_IEEE:
    case TargetOpcode::G_FMINIMUM:
    case TargetOpcode::G_FMAXIMUM:
    // Arithmetic.
    case TargetOpcode::G_FADD:
    case TargetOpcode::G_FSUB:
    case TargetOpcode::G_FMUL:
    case TargetOpcode::G_FDIV:
    case TargetOpcode::G_FPOW:
      return true;
    default:
      break;
    }
    return false;
  }
};

/// Wrapper for a generic logical (bitwise) binary operation: and, or, xor.
class GLogicalBinOp : public GBinOp {
public:
  /// Type-test hook: true exactly when \p MI has opcode G_AND, G_OR or G_XOR.
  static bool classof(const MachineInstr *MI) {
    const unsigned Opc = MI->getOpcode();
    switch (Opc) {
    case TargetOpcode::G_AND:
    case TargetOpcode::G_OR:
    case TargetOpcode::G_XOR:
      return true;
    default:
      break;
    }
    return false;
  }
};

} // namespace llvm

Expand Down
93 changes: 65 additions & 28 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6390,8 +6390,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isOne()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
B.buildZExtOrTrunc(Dest, Inner);
};
return true;
Expand All @@ -6401,8 +6400,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isZero() && FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
B.buildSExtOrTrunc(Dest, Inner);
};
return true;
Expand All @@ -6412,8 +6410,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue - 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Inner, Cond);
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
Expand All @@ -6423,8 +6420,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue + 1 == FalseValue) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
B.buildAdd(Dest, Inner, False);
};
return true;
Expand All @@ -6434,8 +6430,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isPowerOf2() && FalseValue.isZero()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Inner, Cond);
auto Inner = B.buildZExtOrTrunc(TrueTy, Cond);
// The shift amount must be scalar.
LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy;
auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2());
Expand All @@ -6447,8 +6442,7 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (TrueValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Cond);
B.buildOr(Dest, Inner, False, Flags);
};
return true;
Expand All @@ -6458,10 +6452,8 @@ bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
if (FalseValue.isAllOnes()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Not = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Not, Cond);
Register Inner = MRI.createGenericVirtualRegister(TrueTy);
B.buildSExtOrTrunc(Inner, Not);
auto Not = B.buildNot(CondTy, Cond);
auto Inner = B.buildSExtOrTrunc(TrueTy, Not);
B.buildOr(Dest, Inner, True, Flags);
};
return true;
Expand Down Expand Up @@ -6496,8 +6488,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Cond);
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildOr(DstReg, Ext, False, Flags);
};
return true;
Expand All @@ -6508,8 +6499,7 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Cond);
auto Ext = B.buildZExtOrTrunc(TrueTy, Cond);
B.buildAnd(DstReg, Ext, True);
};
return true;
Expand All @@ -6520,11 +6510,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
// Then an ext to match the destination register.
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Inner);
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
B.buildOr(DstReg, Ext, True, Flags);
};
return true;
Expand All @@ -6535,11 +6523,9 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
// First the not.
Register Inner = MRI.createGenericVirtualRegister(CondTy);
B.buildNot(Inner, Cond);
auto Inner = B.buildNot(CondTy, Cond);
// Then an ext to match the destination register.
Register Ext = MRI.createGenericVirtualRegister(TrueTy);
B.buildZExtOrTrunc(Ext, Inner);
auto Ext = B.buildZExtOrTrunc(TrueTy, Inner);
B.buildAnd(DstReg, Ext, False);
};
return true;
Expand All @@ -6548,6 +6534,54 @@ bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select,
return false;
}

/// Try to fold a select whose true and false values are both defined by the
/// same kind of binary operation and share one operand:
///   select(cond, binop(x, y), binop(z, y)) --> binop(select(cond, x, z), y)
///   select(cond, binop(x, y), binop(x, z)) --> binop(x, select(cond, y, z))
/// The shared-RHS form is tested first, so it is the one used when both
/// operands happen to match.
///
/// \param Select     The select instruction being inspected.
/// \param MatchInfo  On a match, assigned a callback that builds the new
///                   select and the new binop with the MachineIRBuilder it is
///                   handed; the binop is built into DstReg.
/// \returns true if one of the two patterns matched and MatchInfo was
///          assigned; false otherwise, in which case MatchInfo is not written.
///
/// The new binop gets the flags common to both original binops plus the
/// flags of the select; the new select gets the select's flags only.
/// Operands are compared by position only, so commuted forms are not matched
/// (see the FIXME at the end of the function).
/// NOTE(review): no use-count check on the two binops is visible here, so they
/// presumably stay alive when they have other users -- confirm this is intended.
bool CombinerHelper::tryFoldSelectOfBinOps(GSelect *Select,
BuildFnTy &MatchInfo) {
// DstReg is the register obtained via getReg(0) on the select; its type is
// also used as the type of the newly built inner select.
Register DstReg = Select->getReg(0);
Register Cond = Select->getCondReg();
Register False = Select->getFalseReg();
Register True = Select->getTrueReg();
LLT DstTy = MRI.getType(DstReg);

// LHS is looked up from the select's true value and RHS from its false
// value; either is null when getOpcodeDef finds no GBinOp for the register.
GBinOp *LHS = getOpcodeDef<GBinOp>(True, MRI);
GBinOp *RHS = getOpcodeDef<GBinOp>(False, MRI);

// We need two binops of the same kind on the true/false registers.
if (!LHS || !RHS || LHS->getOpcode() != RHS->getOpcode())
return false;

// Note that there are no constraints on CondTy.
// Keep only the flags both binops agree on, then add the select's own flags.
unsigned Flags = (LHS->getFlags() & RHS->getFlags()) | Select->getFlags();
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line differs from the Dag combiner.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have split this by matching the DAG behavior in the initial commit but doesn't really matter. I think this is OK but would be nice to have alive verify

unsigned Opcode = LHS->getOpcode();

// Fold select(cond, binop(x, y), binop(z, y))
// --> binop(select(cond, x, z), y)
if (LHS->getRHSReg() == RHS->getRHSReg()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Sel = B.buildSelect(DstTy, Cond, LHS->getLHSReg(), RHS->getLHSReg(),
Select->getFlags());
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And Flags are added to selects.

B.buildInstr(Opcode, {DstReg}, {Sel, LHS->getRHSReg()}, Flags);
};
return true;
}

// Fold select(cond, binop(x, y), binop(x, z))
// --> binop(x, select(cond, y, z))
if (LHS->getLHSReg() == RHS->getLHSReg()) {
MatchInfo = [=](MachineIRBuilder &B) {
B.setInstrAndDebugLoc(*Select);
auto Sel = B.buildSelect(DstTy, Cond, LHS->getRHSReg(), RHS->getRHSReg(),
Select->getFlags());
B.buildInstr(Opcode, {DstReg}, {LHS->getLHSReg(), Sel}, Flags);
};
return true;
}

// FIXME: use isCommutable().

return false;
}

bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
GSelect *Select = cast<GSelect>(&MI);

Expand All @@ -6557,5 +6591,8 @@ bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) {
if (tryFoldBoolSelectToLogic(Select, MatchInfo))
return true;

if (tryFoldSelectOfBinOps(Select, MatchInfo))
return true;

return false;
}
Loading