Skip to content

Commit b3f6a01

Browse files
committed
[AArch64] Eliminate Common Subexpression of CSEL by Reassociation
If we have a CSEL instruction that depends on the flags set by a (SUBS x c) instruction, and the true and/or false expression is (add (add x y) -c), we can reassociate the latter expression to (add (SUBS x c) y) and save one instruction. The transformation works for unsigned comparisons, and for equality comparisons with 0 (by first converting them to unsigned comparisons). Proof for the basic transformation: https://alive2.llvm.org/ce/z/-337Pb. Fixes #119606.
1 parent e91c3b4 commit b3f6a01

File tree

2 files changed

+115
-47
lines changed

2 files changed

+115
-47
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24838,6 +24838,83 @@ static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
2483824838
return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
2483924839
}
2484024840

24841+
// Reassociate the true/false expressions of a CSEL instruction to obtain a
24842+
// common subexpression with the comparison instruction. For example, change
24843+
// (CSEL (ADD (ADD x y) -c) f LO (SUBS x c)) to
24844+
// (CSEL (ADD (SUBS x c) y) f LO (SUBS x c)) such that (SUBS x c) is a common
24845+
// subexpression.
24846+
static SDValue reassociateCSELOperandsForCSE(SDNode *N, SelectionDAG &DAG) {
24847+
SDValue SubsNode = N->getOperand(3);
24848+
if (SubsNode.getOpcode() != AArch64ISD::SUBS || !SubsNode.hasOneUse())
24849+
return SDValue();
24850+
auto *CmpOpConst = dyn_cast<ConstantSDNode>(SubsNode.getOperand(1));
24851+
if (!CmpOpConst)
24852+
return SDValue();
24853+
24854+
auto CC = static_cast<AArch64CC::CondCode>(N->getConstantOperandVal(2));
24855+
bool IsEquality = CC == AArch64CC::EQ || CC == AArch64CC::NE;
24856+
if (CC != AArch64CC::LO && CC != AArch64CC::HI &&
24857+
(!IsEquality || !CmpOpConst->isZero()))
24858+
return SDValue();
24859+
// The cases (x < c) and (x == 0) are later unified as (x < newconst).
24860+
// The cases (x > c) and (x != 0) are later unified as (x >= newconst).
24861+
APInt NewCmpConst = CC == AArch64CC::LO ? CmpOpConst->getAPIntValue()
24862+
: CmpOpConst->getAPIntValue() + 1;
24863+
APInt ExpectedConst = -NewCmpConst;
24864+
24865+
SDValue CmpOpOther = SubsNode.getOperand(0);
24866+
EVT VT = N->getValueType(0);
24867+
SDValue NewCmp = DAG.getNode(AArch64ISD::SUBS, SDLoc(SubsNode),
24868+
DAG.getVTList(VT, MVT_CC), CmpOpOther,
24869+
DAG.getConstant(NewCmpConst, SDLoc(CmpOpConst),
24870+
CmpOpConst->getValueType(0)));
24871+
24872+
auto Reassociate = [&](SDValue Op) {
24873+
if (Op.getOpcode() != ISD::ADD)
24874+
return SDValue();
24875+
auto *AddOpConst = dyn_cast<ConstantSDNode>(Op.getOperand(1));
24876+
if (!AddOpConst)
24877+
return SDValue();
24878+
if (IsEquality && AddOpConst->getAPIntValue() != ExpectedConst)
24879+
return SDValue();
24880+
if (!IsEquality && AddOpConst->getAPIntValue() != ExpectedConst)
24881+
return SDValue();
24882+
if (Op.getOperand(0).getOpcode() != ISD::ADD ||
24883+
!Op.getOperand(0).hasOneUse())
24884+
return SDValue();
24885+
SDValue X = Op.getOperand(0).getOperand(0);
24886+
SDValue Y = Op.getOperand(0).getOperand(1);
24887+
if (X != CmpOpOther)
24888+
std::swap(X, Y);
24889+
if (X != CmpOpOther)
24890+
return SDValue();
24891+
SDNodeFlags Flags;
24892+
if (Op.getOperand(0).getNode()->getFlags().hasNoUnsignedWrap())
24893+
Flags.setNoUnsignedWrap(true);
24894+
return DAG.getNode(ISD::ADD, SDLoc(Op), VT, NewCmp.getValue(0), Y, Flags);
24895+
};
24896+
24897+
SDValue TValReassoc = Reassociate(N->getOperand(0));
24898+
SDValue FValReassoc = Reassociate(N->getOperand(1));
24899+
if (!TValReassoc && !FValReassoc)
24900+
return SDValue();
24901+
if (TValReassoc)
24902+
DAG.ReplaceAllUsesWith(N->getOperand(0), TValReassoc);
24903+
else
24904+
TValReassoc = N->getOperand(0);
24905+
if (FValReassoc)
24906+
DAG.ReplaceAllUsesWith(N->getOperand(1), FValReassoc);
24907+
else
24908+
FValReassoc = N->getOperand(1);
24909+
24910+
AArch64CC::CondCode NewCC = CC == AArch64CC::EQ || CC == AArch64CC::LO
24911+
? AArch64CC::LO
24912+
: AArch64CC::HS;
24913+
return DAG.getNode(AArch64ISD::CSEL, SDLoc(N), VT, TValReassoc, FValReassoc,
24914+
DAG.getConstant(NewCC, SDLoc(N->getOperand(2)), MVT_CC),
24915+
NewCmp.getValue(1));
24916+
}
24917+
2484124918
// Optimize CSEL instructions
2484224919
static SDValue performCSELCombine(SDNode *N,
2484324920
TargetLowering::DAGCombinerInfo &DCI,
@@ -24849,6 +24926,11 @@ static SDValue performCSELCombine(SDNode *N,
2484924926
if (SDValue R = foldCSELOfCSEL(N, DAG))
2485024927
return R;
2485124928

24929+
// Try to reassociate the true/false expressions so that we can do CSE with
24930+
// a SUBS instruction used to perform the comparison.
24931+
if (SDValue R = reassociateCSELOperandsForCSE(N, DAG))
24932+
return R;
24933+
2485224934
// CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
2485324935
// CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
2485424936
if (SDValue Folded = foldCSELofCTTZ(N, DAG))

llvm/test/CodeGen/AArch64/csel-cmp-cse.ll

Lines changed: 33 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ declare void @use_i32(i32 %x)
88
define ptr @test_last_elem_from_ptr(ptr noundef readnone %x0, i64 noundef %x1) {
99
; CHECK-LABEL: test_last_elem_from_ptr:
1010
; CHECK: // %bb.0:
11-
; CHECK-NEXT: add x8, x0, x1
12-
; CHECK-NEXT: cmp x1, #0
13-
; CHECK-NEXT: sub x8, x8, #1
14-
; CHECK-NEXT: csel x0, xzr, x8, eq
11+
; CHECK-NEXT: subs x8, x1, #1
12+
; CHECK-NEXT: add x8, x8, x0
13+
; CHECK-NEXT: csel x0, xzr, x8, lo
1514
; CHECK-NEXT: ret
1615
%cmp = icmp eq i64 %x1, 0
1716
%add.ptr = getelementptr inbounds nuw i8, ptr %x0, i64 %x1
@@ -23,10 +22,9 @@ define ptr @test_last_elem_from_ptr(ptr noundef readnone %x0, i64 noundef %x1) {
2322
define i32 @test_eq0_sub_add_i32(i32 %x0, i32 %x1) {
2423
; CHECK-LABEL: test_eq0_sub_add_i32:
2524
; CHECK: // %bb.0:
26-
; CHECK-NEXT: add w8, w0, w1
27-
; CHECK-NEXT: cmp w1, #0
28-
; CHECK-NEXT: sub w8, w8, #1
29-
; CHECK-NEXT: csel w0, wzr, w8, eq
25+
; CHECK-NEXT: subs w8, w1, #1
26+
; CHECK-NEXT: add w8, w8, w0
27+
; CHECK-NEXT: csel w0, wzr, w8, lo
3028
; CHECK-NEXT: ret
3129
%cmp = icmp eq i32 %x1, 0
3230
%add = add nuw i32 %x0, %x1
@@ -38,9 +36,8 @@ define i32 @test_eq0_sub_add_i32(i32 %x0, i32 %x1) {
3836
define i32 @test_ule7_sub_add_i32(i32 %x0, i32 %x1) {
3937
; CHECK-LABEL: test_ule7_sub_add_i32:
4038
; CHECK: // %bb.0:
41-
; CHECK-NEXT: add w8, w0, w1
42-
; CHECK-NEXT: cmp w1, #8
43-
; CHECK-NEXT: sub w8, w8, #8
39+
; CHECK-NEXT: subs w8, w1, #8
40+
; CHECK-NEXT: add w8, w8, w0
4441
; CHECK-NEXT: csel w0, wzr, w8, lo
4542
; CHECK-NEXT: ret
4643
%cmp = icmp ule i32 %x1, 7
@@ -53,10 +50,9 @@ define i32 @test_ule7_sub_add_i32(i32 %x0, i32 %x1) {
5350
define i32 @test_ule0_sub_add_i32(i32 %x0, i32 %x1) {
5451
; CHECK-LABEL: test_ule0_sub_add_i32:
5552
; CHECK: // %bb.0:
56-
; CHECK-NEXT: add w8, w0, w1
57-
; CHECK-NEXT: cmp w1, #0
58-
; CHECK-NEXT: sub w8, w8, #1
59-
; CHECK-NEXT: csel w0, wzr, w8, eq
53+
; CHECK-NEXT: subs w8, w1, #1
54+
; CHECK-NEXT: add w8, w8, w0
55+
; CHECK-NEXT: csel w0, wzr, w8, lo
6056
; CHECK-NEXT: ret
6157
%cmp = icmp ule i32 %x1, 0
6258
%add = add i32 %x0, %x1
@@ -68,9 +64,8 @@ define i32 @test_ule0_sub_add_i32(i32 %x0, i32 %x1) {
6864
define i32 @test_ultminus2_sub_add_i32(i32 %x0, i32 %x1) {
6965
; CHECK-LABEL: test_ultminus2_sub_add_i32:
7066
; CHECK: // %bb.0:
71-
; CHECK-NEXT: add w8, w0, w1
72-
; CHECK-NEXT: cmn w1, #2
73-
; CHECK-NEXT: add w8, w8, #2
67+
; CHECK-NEXT: adds w8, w1, #2
68+
; CHECK-NEXT: add w8, w8, w0
7469
; CHECK-NEXT: csel w0, wzr, w8, lo
7570
; CHECK-NEXT: ret
7671
%cmp = icmp ult i32 %x1, -2
@@ -83,10 +78,9 @@ define i32 @test_ultminus2_sub_add_i32(i32 %x0, i32 %x1) {
8378
define i32 @test_ne0_sub_add_i32(i32 %x0, i32 %x1) {
8479
; CHECK-LABEL: test_ne0_sub_add_i32:
8580
; CHECK: // %bb.0:
86-
; CHECK-NEXT: add w8, w0, w1
87-
; CHECK-NEXT: cmp w1, #0
88-
; CHECK-NEXT: sub w8, w8, #1
89-
; CHECK-NEXT: csel w0, w8, wzr, ne
81+
; CHECK-NEXT: subs w8, w1, #1
82+
; CHECK-NEXT: add w8, w8, w0
83+
; CHECK-NEXT: csel w0, w8, wzr, hs
9084
; CHECK-NEXT: ret
9185
%cmp = icmp ne i32 %x1, 0
9286
%add = add i32 %x0, %x1
@@ -98,10 +92,9 @@ define i32 @test_ne0_sub_add_i32(i32 %x0, i32 %x1) {
9892
define i32 @test_ugt7_sub_add_i32(i32 %x0, i32 %x1) {
9993
; CHECK-LABEL: test_ugt7_sub_add_i32:
10094
; CHECK: // %bb.0:
101-
; CHECK-NEXT: add w8, w0, w1
102-
; CHECK-NEXT: cmp w1, #7
103-
; CHECK-NEXT: sub w8, w8, #8
104-
; CHECK-NEXT: csel w0, wzr, w8, hi
95+
; CHECK-NEXT: subs w8, w1, #8
96+
; CHECK-NEXT: add w8, w8, w0
97+
; CHECK-NEXT: csel w0, wzr, w8, hs
10598
; CHECK-NEXT: ret
10699
%cmp = icmp ugt i32 %x1, 7
107100
%add = add i32 %x0, %x1
@@ -113,10 +106,9 @@ define i32 @test_ugt7_sub_add_i32(i32 %x0, i32 %x1) {
113106
define i32 @test_eq0_sub_addcomm_i32(i32 %x0, i32 %x1) {
114107
; CHECK-LABEL: test_eq0_sub_addcomm_i32:
115108
; CHECK: // %bb.0:
116-
; CHECK-NEXT: add w8, w1, w0
117-
; CHECK-NEXT: cmp w1, #0
118-
; CHECK-NEXT: sub w8, w8, #1
119-
; CHECK-NEXT: csel w0, wzr, w8, eq
109+
; CHECK-NEXT: subs w8, w1, #1
110+
; CHECK-NEXT: add w8, w8, w0
111+
; CHECK-NEXT: csel w0, wzr, w8, lo
120112
; CHECK-NEXT: ret
121113
%cmp = icmp eq i32 %x1, 0
122114
%add = add i32 %x1, %x0
@@ -128,10 +120,9 @@ define i32 @test_eq0_sub_addcomm_i32(i32 %x0, i32 %x1) {
128120
define i32 @test_eq0_subcomm_add_i32(i32 %x0, i32 %x1) {
129121
; CHECK-LABEL: test_eq0_subcomm_add_i32:
130122
; CHECK: // %bb.0:
131-
; CHECK-NEXT: add w8, w0, w1
132-
; CHECK-NEXT: cmp w1, #0
133-
; CHECK-NEXT: sub w8, w8, #1
134-
; CHECK-NEXT: csel w0, wzr, w8, eq
123+
; CHECK-NEXT: subs w8, w1, #1
124+
; CHECK-NEXT: add w8, w8, w0
125+
; CHECK-NEXT: csel w0, wzr, w8, lo
135126
; CHECK-NEXT: ret
136127
%cmp = icmp eq i32 %x1, 0
137128
%add = add i32 %x0, %x1
@@ -143,21 +134,16 @@ define i32 @test_eq0_subcomm_add_i32(i32 %x0, i32 %x1) {
143134
define i32 @test_eq0_multi_use_sub_i32(i32 %x0, i32 %x1) {
144135
; CHECK-LABEL: test_eq0_multi_use_sub_i32:
145136
; CHECK: // %bb.0:
146-
; CHECK-NEXT: str x30, [sp, #-32]! // 8-byte Folded Spill
147-
; CHECK-NEXT: stp x20, x19, [sp, #16] // 16-byte Folded Spill
148-
; CHECK-NEXT: .cfi_def_cfa_offset 32
137+
; CHECK-NEXT: stp x30, x19, [sp, #-16]! // 16-byte Folded Spill
138+
; CHECK-NEXT: .cfi_def_cfa_offset 16
149139
; CHECK-NEXT: .cfi_offset w19, -8
150-
; CHECK-NEXT: .cfi_offset w20, -16
151-
; CHECK-NEXT: .cfi_offset w30, -32
152-
; CHECK-NEXT: add w8, w0, w1
153-
; CHECK-NEXT: mov w19, w1
154-
; CHECK-NEXT: sub w20, w8, #1
155-
; CHECK-NEXT: mov w0, w20
140+
; CHECK-NEXT: .cfi_offset w30, -16
141+
; CHECK-NEXT: subs w8, w1, #1
142+
; CHECK-NEXT: add w0, w8, w0
143+
; CHECK-NEXT: csel w19, wzr, w0, lo
156144
; CHECK-NEXT: bl use_i32
157-
; CHECK-NEXT: cmp w19, #0
158-
; CHECK-NEXT: csel w0, wzr, w20, eq
159-
; CHECK-NEXT: ldp x20, x19, [sp, #16] // 16-byte Folded Reload
160-
; CHECK-NEXT: ldr x30, [sp], #32 // 8-byte Folded Reload
145+
; CHECK-NEXT: mov w0, w19
146+
; CHECK-NEXT: ldp x30, x19, [sp], #16 // 16-byte Folded Reload
161147
; CHECK-NEXT: ret
162148
%cmp = icmp eq i32 %x1, 0
163149
%add = add nuw i32 %x0, %x1

0 commit comments

Comments
 (0)