Skip to content

Commit d2db014

Browse files
author
Thorsten Schütt
committed
[GlobalIsel] Combine ADDO
Perform the requested arithmetic and produce a carry output in addition to the normal result. Clang has them as builtins (__builtin_add_overflow_p). The middle end has intrinsics for them (sadd_with_overflow). AArch64: ADDS Add and set flags On Neoverse V2, they run at half the throughput of basic arithmetic and have a limited set of pipelines.
1 parent 2582965 commit d2db014

File tree

13 files changed

+1029
-752
lines changed

13 files changed

+1029
-752
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -696,10 +696,6 @@ class CombinerHelper {
696696
/// (G_*MULO x, 0) -> 0 + no carry out
697697
bool matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo);
698698

699-
/// Match:
700-
/// (G_*ADDO x, 0) -> x + no carry out
701-
bool matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo);
702-
703699
/// Match:
704700
/// (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
705701
/// (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
@@ -810,12 +806,15 @@ class CombinerHelper {
810806
/// Combine selects.
811807
bool matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo);
812808

813-
/// Combine ands,
809+
/// Combine ands.
814810
bool matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo);
815811

816-
/// Combine ors,
812+
/// Combine ors.
817813
bool matchOr(MachineInstr &MI, BuildFnTy &MatchInfo);
818814

815+
/// Combine addos.
816+
bool matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo);
817+
819818
private:
820819
/// Checks for legality of an indexed variant of \p LdSt.
821820
bool isIndexedLoadStoreLegal(GLoadStore &LdSt) const;
@@ -919,6 +918,7 @@ class CombinerHelper {
919918
bool isZeroOrZeroSplat(Register Src, bool AllowUndefs);
920919
bool isConstantSplatVector(Register Src, int64_t SplatValue,
921920
bool AllowUndefs);
921+
bool isConstantOrConstantVectorI(Register Src) const;
922922

923923
std::optional<APInt> getConstantOrConstantSplatVector(Register Src);
924924

llvm/include/llvm/CodeGen/GlobalISel/GenericMachineInstrs.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,8 @@ class GBinOpCarryOut : public GenericMachineInstr {
359359
Register getCarryOutReg() const { return getReg(1); }
360360
MachineOperand &getLHS() { return getOperand(2); }
361361
MachineOperand &getRHS() { return getOperand(3); }
362+
Register getLHSReg() const { return getOperand(2).getReg(); }
363+
Register getRHSReg() const { return getOperand(3).getReg(); }
362364

363365
static bool classof(const MachineInstr *MI) {
364366
switch (MI->getOpcode()) {
@@ -429,6 +431,23 @@ class GAddSubCarryOut : public GBinOpCarryOut {
429431
}
430432
};
431433

434+
/// Represents overflowing add operations.
435+
/// G_UADDO, G_SADDO
436+
class GAddCarryOut : public GBinOpCarryOut {
437+
public:
438+
bool isSigned() const { return getOpcode() == TargetOpcode::G_SADDO; }
439+
440+
static bool classof(const MachineInstr *MI) {
441+
switch (MI->getOpcode()) {
442+
case TargetOpcode::G_UADDO:
443+
case TargetOpcode::G_SADDO:
444+
return true;
445+
default:
446+
return false;
447+
}
448+
}
449+
};
450+
432451
/// Represents overflowing add/sub operations that also consume a carry-in.
433452
/// G_UADDE, G_SADDE, G_USUBE, G_SSUBE
434453
class GAddSubCarryInOut : public GAddSubCarryOut {

llvm/include/llvm/Target/GlobalISel/Combine.td

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,12 +1090,6 @@ def mulo_by_0: GICombineRule<
10901090
[{ return Helper.matchMulOBy0(*${root}, ${matchinfo}); }]),
10911091
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
10921092

1093-
def addo_by_0: GICombineRule<
1094-
(defs root:$root, build_fn_matchinfo:$matchinfo),
1095-
(match (wip_match_opcode G_UADDO, G_SADDO):$root,
1096-
[{ return Helper.matchAddOBy0(*${root}, ${matchinfo}); }]),
1097-
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
1098-
10991093
// Transform (uadde x, y, 0) -> (uaddo x, y)
11001094
// (sadde x, y, 0) -> (saddo x, y)
11011095
// (usube x, y, 0) -> (usubo x, y)
@@ -1291,6 +1285,12 @@ def match_ors : GICombineRule<
12911285
[{ return Helper.matchOr(*${root}, ${matchinfo}); }]),
12921286
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
12931287

1288+
def match_addos : GICombineRule<
1289+
(defs root:$root, build_fn_matchinfo:$matchinfo),
1290+
(match (wip_match_opcode G_SADDO, G_UADDO):$root,
1291+
[{ return Helper.matchAddOverflow(*${root}, ${matchinfo}); }]),
1292+
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
1293+
12941294
// Combines concat operations
12951295
def concat_matchinfo : GIDefMatchData<"SmallVector<Register>">;
12961296
def combine_concat_vector : GICombineRule<
@@ -1326,7 +1326,7 @@ def identity_combines : GICombineGroup<[select_same_val, right_identity_zero,
13261326

13271327
def const_combines : GICombineGroup<[constant_fold_fp_ops, const_ptradd_to_i2p,
13281328
overlapping_and, mulo_by_2, mulo_by_0,
1329-
addo_by_0, adde_to_addo,
1329+
adde_to_addo,
13301330
combine_minmax_nan]>;
13311331

13321332
def known_bits_simplifications : GICombineGroup<[
@@ -1374,7 +1374,7 @@ def all_combines : GICombineGroup<[trivial_combines, insert_vec_elt_combines,
13741374
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
13751375
sub_add_reg, select_to_minmax, redundant_binop_in_equality,
13761376
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
1377-
combine_concat_vector, double_icmp_zero_and_or_combine]>;
1377+
combine_concat_vector, double_icmp_zero_and_or_combine, match_addos]>;
13781378

13791379
// A combine group used to for prelegalizer combiners at -O0. The combines in
13801380
// this group have been selected based on experiments to balance code size and

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 216 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4936,24 +4936,6 @@ bool CombinerHelper::matchMulOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
49364936
return true;
49374937
}
49384938

4939-
bool CombinerHelper::matchAddOBy0(MachineInstr &MI, BuildFnTy &MatchInfo) {
4940-
// (G_*ADDO x, 0) -> x + no carry out
4941-
assert(MI.getOpcode() == TargetOpcode::G_UADDO ||
4942-
MI.getOpcode() == TargetOpcode::G_SADDO);
4943-
if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0)))
4944-
return false;
4945-
Register Carry = MI.getOperand(1).getReg();
4946-
if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Carry)))
4947-
return false;
4948-
Register Dst = MI.getOperand(0).getReg();
4949-
Register LHS = MI.getOperand(2).getReg();
4950-
MatchInfo = [=](MachineIRBuilder &B) {
4951-
B.buildCopy(Dst, LHS);
4952-
B.buildConstant(Carry, 0);
4953-
};
4954-
return true;
4955-
}
4956-
49574939
bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, BuildFnTy &MatchInfo) {
49584940
// (G_*ADDE x, y, 0) -> (G_*ADDO x, y)
49594941
// (G_*SUBE x, y, 0) -> (G_*SUBO x, y)
@@ -6354,6 +6336,26 @@ CombinerHelper::getConstantOrConstantSplatVector(Register Src) {
63546336
return Value;
63556337
}
63566338

6339+
// FIXME G_SPLAT_VECTOR
6340+
bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const {
6341+
auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI);
6342+
if (IConstant)
6343+
return true;
6344+
6345+
GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI);
6346+
if (!BuildVector)
6347+
return false;
6348+
6349+
unsigned NumSources = BuildVector->getNumSources();
6350+
for (unsigned I = 0; I < NumSources; ++I) {
6351+
std::optional<ValueAndVReg> IConstant =
6352+
getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI);
6353+
if (!IConstant)
6354+
return false;
6355+
}
6356+
return true;
6357+
}
6358+
63576359
// TODO: use knownbits to determine zeros
63586360
bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select,
63596361
BuildFnTy &MatchInfo) {
@@ -6928,3 +6930,199 @@ bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) {
69286930

69296931
return false;
69306932
}
6933+
6934+
bool CombinerHelper::matchAddOverflow(MachineInstr &MI, BuildFnTy &MatchInfo) {
6935+
GAddCarryOut *Add = cast<GAddCarryOut>(&MI);
6936+
6937+
// Addo has no flags
6938+
Register Dst = Add->getReg(0);
6939+
Register Carry = Add->getReg(1);
6940+
Register LHS = Add->getLHSReg();
6941+
Register RHS = Add->getRHSReg();
6942+
bool IsSigned = Add->isSigned();
6943+
LLT DstTy = MRI.getType(Dst);
6944+
LLT CarryTy = MRI.getType(Carry);
6945+
6946+
// We want do fold the [u|s]addo.
6947+
if (!MRI.hasOneNonDBGUse(Dst))
6948+
return false;
6949+
6950+
// Fold addo, if the carry is dead -> add, undef.
6951+
if (MRI.use_nodbg_empty(Carry) &&
6952+
isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) {
6953+
MatchInfo = [=](MachineIRBuilder &B) {
6954+
B.buildAdd(Dst, LHS, RHS);
6955+
B.buildUndef(Carry);
6956+
};
6957+
return true;
6958+
}
6959+
6960+
// We want do fold the [u|s]addo.
6961+
if (!MRI.hasOneNonDBGUse(Carry))
6962+
return false;
6963+
6964+
// Canonicalize constant to RHS.
6965+
if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) {
6966+
if (IsSigned) {
6967+
MatchInfo = [=](MachineIRBuilder &B) {
6968+
B.buildSAddo(Dst, Carry, RHS, LHS);
6969+
};
6970+
return true;
6971+
} else {
6972+
MatchInfo = [=](MachineIRBuilder &B) {
6973+
B.buildUAddo(Dst, Carry, RHS, LHS);
6974+
};
6975+
return true;
6976+
}
6977+
}
6978+
6979+
std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS);
6980+
std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS);
6981+
6982+
// Fold addo(c1, c2) -> c3, carry.
6983+
if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) &&
6984+
isConstantLegalOrBeforeLegalizer(CarryTy)) {
6985+
// They must both have the same bitwidth. Otherwise APInt might
6986+
// assert. Pre legalization, they may have widely different bitwidths.
6987+
unsigned BitWidth =
6988+
std::max(MaybeLHS->getBitWidth(), MaybeRHS->getBitWidth());
6989+
bool Overflow;
6990+
APInt Result;
6991+
if (IsSigned) {
6992+
APInt LHS = MaybeLHS->sext(BitWidth);
6993+
APInt RHS = MaybeRHS->sext(BitWidth);
6994+
Result = LHS.sadd_ov(RHS, Overflow);
6995+
} else {
6996+
APInt LHS = MaybeLHS->zext(BitWidth);
6997+
APInt RHS = MaybeRHS->zext(BitWidth);
6998+
Result = LHS.uadd_ov(RHS, Overflow);
6999+
}
7000+
MatchInfo = [=](MachineIRBuilder &B) {
7001+
B.buildConstant(Dst, Result);
7002+
B.buildConstant(Carry, Overflow);
7003+
};
7004+
return true;
7005+
}
7006+
7007+
// Fold (addo x, 0) -> x, no borrow
7008+
if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) {
7009+
MatchInfo = [=](MachineIRBuilder &B) {
7010+
B.buildCopy(Dst, LHS);
7011+
B.buildConstant(Carry, 0);
7012+
};
7013+
return true;
7014+
}
7015+
7016+
// Given 2 constant operands whose sum does not overflow:
7017+
// uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1
7018+
// saddo (X +nsw C0), C1 -> saddo X, C0 + C1
7019+
GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI);
7020+
if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) &&
7021+
((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) ||
7022+
(!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) {
7023+
std::optional<APInt> MaybeAddRHS =
7024+
getConstantOrConstantSplatVector(AddLHS->getRHSReg());
7025+
if (MaybeAddRHS) {
7026+
unsigned BitWidth =
7027+
std::max(MaybeRHS->getBitWidth(), MaybeAddRHS->getBitWidth());
7028+
bool Overflow;
7029+
APInt NewC;
7030+
if (IsSigned) {
7031+
APInt LHS = MaybeRHS->sext(BitWidth);
7032+
APInt RHS = MaybeAddRHS->sext(BitWidth);
7033+
NewC = LHS.sadd_ov(RHS, Overflow);
7034+
} else {
7035+
APInt LHS = MaybeRHS->zext(BitWidth);
7036+
APInt RHS = MaybeAddRHS->zext(BitWidth);
7037+
NewC = LHS.uadd_ov(RHS, Overflow);
7038+
}
7039+
if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) {
7040+
if (IsSigned) {
7041+
MatchInfo = [=](MachineIRBuilder &B) {
7042+
auto ConstRHS = B.buildConstant(DstTy, NewC);
7043+
B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
7044+
};
7045+
return true;
7046+
} else {
7047+
MatchInfo = [=](MachineIRBuilder &B) {
7048+
auto ConstRHS = B.buildConstant(DstTy, NewC);
7049+
B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS);
7050+
};
7051+
return true;
7052+
}
7053+
}
7054+
}
7055+
};
7056+
7057+
// We try to combine uaddo to non-overflowing add.
7058+
if (!IsSigned && isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) &&
7059+
isConstantLegalOrBeforeLegalizer(DstTy)) {
7060+
ConstantRange CRLHS =
7061+
ConstantRange::fromKnownBits(KB->getKnownBits(LHS), false /*IsSigned*/);
7062+
ConstantRange CRRHS =
7063+
ConstantRange::fromKnownBits(KB->getKnownBits(RHS), false /*IsSigned*/);
7064+
7065+
switch (CRLHS.unsignedAddMayOverflow(CRRHS)) {
7066+
case ConstantRange::OverflowResult::MayOverflow:
7067+
return false;
7068+
case ConstantRange::OverflowResult::NeverOverflows: {
7069+
MatchInfo = [=](MachineIRBuilder &B) {
7070+
B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
7071+
B.buildConstant(Carry, 0);
7072+
};
7073+
return true;
7074+
}
7075+
case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7076+
case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
7077+
MatchInfo = [=](MachineIRBuilder &B) {
7078+
B.buildAdd(Dst, LHS, RHS);
7079+
B.buildConstant(Carry, 1);
7080+
};
7081+
return true;
7082+
}
7083+
};
7084+
return false;
7085+
};
7086+
7087+
// We try to combine saddo to non-overflowing add.
7088+
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) ||
7089+
!isConstantLegalOrBeforeLegalizer(CarryTy))
7090+
return false;
7091+
7092+
// If LHS and RHS each have at least two sign bits, then there is no signed
7093+
// overflow.
7094+
if (KB->computeNumSignBits(LHS) > 1 && KB->computeNumSignBits(RHS) > 1) {
7095+
MatchInfo = [=](MachineIRBuilder &B) {
7096+
B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
7097+
B.buildConstant(Carry, 0);
7098+
};
7099+
return true;
7100+
}
7101+
7102+
ConstantRange CRLHS =
7103+
ConstantRange::fromKnownBits(KB->getKnownBits(LHS), true /*IsSigned*/);
7104+
ConstantRange CRRHS =
7105+
ConstantRange::fromKnownBits(KB->getKnownBits(RHS), true /*IsSigned*/);
7106+
7107+
switch (CRLHS.signedAddMayOverflow(CRRHS)) {
7108+
case ConstantRange::OverflowResult::MayOverflow:
7109+
return false;
7110+
case ConstantRange::OverflowResult::NeverOverflows: {
7111+
MatchInfo = [=](MachineIRBuilder &B) {
7112+
B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
7113+
B.buildConstant(Carry, 0);
7114+
};
7115+
return true;
7116+
}
7117+
case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7118+
case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
7119+
MatchInfo = [=](MachineIRBuilder &B) {
7120+
B.buildAdd(Dst, LHS, RHS);
7121+
B.buildConstant(Carry, 1);
7122+
};
7123+
return true;
7124+
}
7125+
};
7126+
7127+
return false;
7128+
}

0 commit comments

Comments
 (0)