
[GlobalIsel] Combine ADDO #82927


Merged: 5 commits, Mar 14, 2024
12 changes: 6 additions & 6 deletions llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h
@@ -696,10 +696,6 @@ class CombinerHelper {
/// (G_*MULO x, 0) -> 0 + no carry out
bool matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Match:
/// (G_*ADDO x, 0) -> x + no carry out
bool matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Match:
/// (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
/// (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
@@ -810,12 +806,15 @@ class CombinerHelper {
/// Combine selects.
bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Combine ands,
/// Combine ands.
bool matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Combine ors,
/// Combine ors.
bool matchOr(MachineInstr &MI, BuildFnTy &MatchInfo);

/// Combine addos.
bool matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo);

private:
/// Checks for legality of an indexed variant of \p LdSt.
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -919,6 +918,7 @@ class CombinerHelper {
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
bool isConstantSplatVector(Register Src, int64_t SplatValue,
bool AllowUndefs);
bool isConstantOrConstantVectorI(Register Src) const;

std::optional<APInt> getConstantOrConstantSplatVector(Register Src);

19 changes: 19 additions & 0 deletions llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h
@@ -359,6 +359,8 @@ class GBinOpCarryOut : public GenericMachineInstr {
Register getCarryOutReg() const { return getReg(1); }
MachineOperand &getLHS() { return getOperand(2); }
MachineOperand &getRHS() { return getOperand(3); }
Register getLHSReg() const { return getOperand(2).getReg(); }
Register getRHSReg() const { return getOperand(3).getReg(); }

static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
@@ -429,6 +431,23 @@ class GAddSubCarryOut : public GBinOpCarryOut {
}
};

/// Represents overflowing add operations.
/// G_UADDO, G_SADDO
class GAddCarryOut : public GBinOpCarryOut {
Contributor:

There is an odd interference with the more general GAddSubCarryOut. Do you really need this class?

Author:

The common pattern is to assert that only the expected opcode is in MI; I use cast<GAddCarryOut>. I don't want an unnoticed sub to slip in.

Member:

GAddSubCarryOut has an isSub() method for that purpose.
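
For reference, a minimal sketch of the two alternatives discussed in this thread (a fragment based only on the APIs mentioned above, not code from the patch):

// Option A (this patch): the cast itself asserts that MI is G_UADDO or G_SADDO.
GAddCarryOut *Add = cast<GAddCarryOut>(&MI);

// Option B: use the broader class and reject the subtraction opcodes explicitly.
GAddSubCarryOut *AddSub = cast<GAddSubCarryOut>(&MI);
if (AddSub->isSub())
  return false;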

public:
bool isSigned() const { return getOpcode() == TargetOpcode::G_SADDO; }

static bool classof(const MachineInstr *MI) {
switch (MI->getOpcode()) {
case TargetOpcode::G_UADDO:
case TargetOpcode::G_SADDO:
return true;
default:
return false;
}
}
};

/// Represents overflowing add/sub operations that also consume a carry-in.
/// G_UADDE, G_SADDE, G_USUBE, G_SSUBE
class GAddSubCarryInOut : public GAddSubCarryOut {
16 changes: 8 additions & 8 deletions llvm/include/llvm/Target/GlobalISel/Combine.td
@@ -1090,12 +1090,6 @@ def mulo_by_0: GICombineRule<
[{ return Helper.matchMulOBy0(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def addo_by_0: GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_UADDO, G_SADDO):$root,
[{ return Helper.matchAddOBy0(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// Transform (uadde x, y, 0) -> (uaddo x, y)
// (sadde x, y, 0) -> (saddo x, y)
// (usube x, y, 0) -> (usubo x, y)
@@ -1291,6 +1285,12 @@ def match_ors : GICombineRule<
[{ return Helper.matchOr(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

def match_addos : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_SADDO, G_UADDO):$root,
[{ return Helper.matchAddOverflow(*${root}, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

Contributor:

I don't know about the current direction, but if we're going to tablegen as much as possible, then we would need separate rules for each of the transformations (constant folding / swapping constant to RHS / replacing with G_ADD etc.).

// Combines concat operations
def concat_matchinfo : GIDefMatchData<"SmallVector<Register>">;
def combine_concat_vector : GICombineRule<
@@ -1326,7 +1326,7 @@ def identity_combines : GICombineGroup<[select_same_val, right_identity_zero,

def const_combines : GICombineGroup<[constant_fold_fp_ops, const_ptradd_to_i2p,
overlapping_and, mulo_by_2, mulo_by_0,
addo_by_0, adde_to_addo,
adde_to_addo,
combine_minmax_nan]>;

def known_bits_simplifications : GICombineGroup<[
@@ -1374,7 +1374,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
combine_concat_vector, double_icmp_zero_and_or_combine]>;
combine_concat_vector, double_icmp_zero_and_or_combine, match_addos]>;

// A combine group used to for prelegalizer combiners at -O0. The combines in
// this group have been selected based on experiments to balance code size and
213 changes: 195 additions & 18 deletions llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp
@@ -4936,24 +4936,6 @@ bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
return true;
}

bool CombinerHelper::matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
// (G_*ADDO x, 0) -> x + no carry out
assert(MI.getOpcode() == TargetOpcode::G_UADDO ||
MI.getOpcode() == TargetOpcode::G_SADDO);
if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
return false;
Register Carry = MI.getOperand(1).getReg();
if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
return false;
Register Dst = MI.getOperand(0).getReg();
Register LHS = MI.getOperand(2).getReg();
MatchInfo = [=](MachineIRBuilder &B) {
B.buildCopy(Dst, LHS);
B.buildConstant(Carry, 0);
};
return true;
}

bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) {
// (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
// (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
@@ -6354,6 +6336,26 @@ CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
return Value;
}

// FIXME G_SPLAT_VECTOR
bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
if (IConstant)
return true;

GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
if (!BuildVector)
return false;

unsigned NumSources = BuildVector->getNumSources();
for (unsigned I = 0; I < NumSources; ++I) {
std::optional<ValueAndVReg> IConstant =
getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
if (!IConstant)
return false;
}
return true;
}

// TODO: use knownbits to determine zeros
bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
BuildFnTy &MatchInfo) {
@@ -6928,3 +6930,178 @@ bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {

return false;
}

bool CombinerHelper::matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo) {
GAddCarryOut *Add = cast<GAddCarryOut>(&MI);
Author:

Replacing this line with:

GAddSubCarryOut *Add = cast<GAddSubCarryOut>(&MI);

seems dangerous and defeats the purpose of the assert in the cast.


// Addo has no flags
Register Dst = Add->getReg(0);
Register Carry = Add->getReg(1);
Register LHS = Add->getLHSReg();
Register RHS = Add->getRHSReg();
bool IsSigned = Add->isSigned();
LLT DstTy = MRI.getType(Dst);
LLT CarryTy = MRI.getType(Carry);

// We want do fold the [u|s]addo.
Contributor:

Typo "want to", but the comment seems unrelated to the code. Why would multiple uses of the top level instruction prevent combining?

Author:

If the result Dst has multiple uses, then we cannot replace it.

Contributor:

Yes we can and we should. The combine will still be correct and beneficial. You can add tests for the multiple-use case.

As a general rule, if you are matching a pattern, you only need one-use checks for the inner nodes in the pattern, not the outer node. The reason for the checks is to ensure that when you remove the outer node, the inner nodes also get removed.
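
As a minimal illustration of that rule (hypothetical helper, not part of this patch): the one-use check belongs on the result of the inner instruction that the fold deletes, not on the outer instruction being rewritten.

#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
using namespace llvm;

// If the inner instruction's result has other users, folding it into the
// outer instruction would leave the inner instruction alive and duplicate
// its work; the outer result needs no such check.
static bool innerFoldIsProfitable(const MachineRegisterInfo &MRI,
                                  const MachineInstr &Inner) {
  return MRI.hasOneNonDBGUse(Inner.getOperand(0).getReg());
}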

if (!MRI.hasOneNonDBGUse(Dst))
return false;

// Fold addo, if the carry is dead -> add, undef.
if (MRI.use_nodbg_empty(Carry) &&
isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildAdd(Dst, LHS, RHS);
B.buildUndef(Carry);
};
return true;
}

// We want do fold the [u|s]addo.
Contributor:

Likewise.

if (!MRI.hasOneNonDBGUse(Carry))
Contributor:

It is not obvious why multiple uses prevent folding / make it unprofitable. Could you clarify the comment?
Same for Dst above.

Author:

Sure, will do. If we have more than one use, then we generate new instructions and keep the old ones.

return false;

// Canonicalize constant to RHS.
if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) {
if (IsSigned) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildSAddo(Dst, Carry, RHS, LHS);
};
return true;
Contributor:

Nit: personally I think early return hurts symmetry here. I would prefer if/else followed by a single "return".

Author:

The anti-symmetry is a result of another else-after-return violation.

}
// !IsSigned
MatchInfo = [=](MachineIRBuilder &B) {
B.buildUAddo(Dst, Carry, RHS, LHS);
};
return true;
}

std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS);
std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS);
Contributor:

early exit after RHS, then LHS

Author:

We can still try the known bits optimizations, even when both are std::nullopt.


// Fold addo(c1, c2) -> c3, carry.
if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) &&
isConstantLegalOrBeforeLegalizer(CarryTy)) {
bool Overflow;
APInt Result = IsSigned ? MaybeLHS->sadd_ov(*MaybeRHS, Overflow)
: MaybeLHS->uadd_ov(*MaybeRHS, Overflow);
MatchInfo = [=](MachineIRBuilder &B) {
B.buildConstant(Dst, Result);
B.buildConstant(Carry, Overflow);
Comment on lines +6989 to +6990
Contributor:

In the vector-typed case don't you need to build new splat constants here? The patch is missing vector-typed tests for all these folds.

Author:

AArch64 does not support vectorized overflow ops. Under the hood, buildConstant builds scalars or build vectors. Support for G_SPLAT_VECTOR is still missing.

Contributor:

> AArch64 does not support vectorized overflow ops.

You should still add the tests. Your combine-overflow.mir test runs pre-legalization, so any MIR should be allowed there.

> Under the hood, buildConstant builds scalars or build vectors.

Ah! I didn't know that, thanks.
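
To make the point concrete, a small sketch of what the exchange above assumes about buildConstant with a vector destination type (illustrative only; B stands for a MachineIRBuilder):

LLT S32 = LLT::scalar(32);
LLT V4S32 = LLT::fixed_vector(4, S32);
// Scalar destination type: a single G_CONSTANT.
auto ScalarC = B.buildConstant(S32, 7);
// Fixed-vector destination type: a splat, i.e. a G_BUILD_VECTOR of G_CONSTANTs.
auto SplatC = B.buildConstant(V4S32, 7);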

};
return true;
}

// Fold (addo x, 0) -> x, no borrow
Contributor:

Nit: "carry" not "borrow"

if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildCopy(Dst, LHS);
B.buildConstant(Carry, 0);
};
return true;
}
Comment on lines +6995 to +7002
Contributor:

could be moved to pure tablegen as a separate pattern


// Given 2 constant operands whose sum does not overflow:
// uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
// saddo (X +nsw C0), C1 -> saddo X, C0 + C1
GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI);
if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) &&
((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) ||
(!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) {
Comment on lines +7009 to +7010
Contributor:

Suggested change:
- ((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) ||
-  (!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) {
+ AddLHS->getFlag(IsSigned ? MachineInstr::MIFlag::NoSWrap : MachineInstr::MIFlag::NoUWrap)) {

std::optional<APInt> MaybeAddRHS =
getConstantOrConstantSplatVector(AddLHS->getRHSReg());
if (MaybeAddRHS) {
bool Overflow;
APInt NewC = IsSigned ? MaybeAddRHS->sadd_ov(*MaybeRHS, Overflow)
: MaybeAddRHS->uadd_ov(*MaybeRHS, Overflow);
if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) {
Contributor:

I don't think you need to check isConstantLegalOrBeforeLegalizer(DstTy) because it's the same type as MaybeRHS and MaybeAddRHS which we already know are constants.

if (IsSigned) {
MatchInfo = [=](MachineIRBuilder &B) {
auto ConstRHS = B.buildConstant(DstTy, NewC);
B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
};
return true;
}
// !IsSigned
MatchInfo = [=](MachineIRBuilder &B) {
auto ConstRHS = B.buildConstant(DstTy, NewC);
B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
};
return true;
}
}
};

// We try to combine addo to non-overflowing add.
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) ||
!isConstantLegalOrBeforeLegalizer(CarryTy))
return false;

// We try to combine uaddo to non-overflowing add.
if (!IsSigned) {
ConstantRange CRLHS =
ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/false);
ConstantRange CRRHS =
ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/false);

switch (CRLHS.unsignedAddMayOverflow(CRRHS)) {
case ConstantRange::OverflowResult::MayOverflow:
return false;
case ConstantRange::OverflowResult::NeverOverflows: {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
B.buildConstant(Carry, 0);
};
return true;
}
case ConstantRange::OverflowResult::AlwaysOverflowsLow:
case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildAdd(Dst, LHS, RHS);
B.buildConstant(Carry, 1);
};
return true;
}
}
return false;
}

// We try to combine saddo to non-overflowing add.

// If LHS and RHS each have at least two sign bits, then there is no signed
// overflow.
if (KB->computeNumSignBits(RHS) > 1 && KB->computeNumSignBits(LHS) > 1) {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
B.buildConstant(Carry, 0);
};
return true;
}

ConstantRange CRLHS =
ConstantRange::fromKnownBits(KB->getKnownBits(LHS), /*IsSigned=*/true);
ConstantRange CRRHS =
ConstantRange::fromKnownBits(KB->getKnownBits(RHS), /*IsSigned=*/true);

switch (CRLHS.signedAddMayOverflow(CRRHS)) {
case ConstantRange::OverflowResult::MayOverflow:
return false;
case ConstantRange::OverflowResult::NeverOverflows: {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
B.buildConstant(Carry, 0);
};
return true;
}
case ConstantRange::OverflowResult::AlwaysOverflowsLow:
case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
MatchInfo = [=](MachineIRBuilder &B) {
B.buildAdd(Dst, LHS, RHS);
B.buildConstant(Carry, 1);
};
return true;
}
}

return false;
}
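
For reference, a standalone sketch (not part of the patch) of the ConstantRange reasoning the unsigned path relies on: when the known bits bound both operands tightly enough, unsignedAddMayOverflow reports NeverOverflows and the addo can be rewritten as a plain G_ADD with a constant-0 carry.

#include "llvm/IR/ConstantRange.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

int main() {
  KnownBits LHS(8), RHS(8);
  LHS.Zero.setBits(6, 8); // top two bits known zero -> value is in [0, 63]
  RHS.Zero.setBits(6, 8);
  ConstantRange CRLHS = ConstantRange::fromKnownBits(LHS, /*IsSigned=*/false);
  ConstantRange CRRHS = ConstantRange::fromKnownBits(RHS, /*IsSigned=*/false);
  // 63 + 63 = 126 fits in 8 bits, so the result is NeverOverflows.
  return CRLHS.unsignedAddMayOverflow(CRRHS) ==
                 ConstantRange::OverflowResult::NeverOverflows
             ? 0
             : 1;
}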