Commit 6a8d8f3

[AArch64][DAGCombiner]: combine <2xi64> add/sub.
64-bit vector mul is not supported in NEON, so we use SVE's mul. To improve performance, we can go one step further and use SVE's add/sub as well, so that SVE's mla/mls can be selected. This works on the patterns:

    add v1, (mul v2, v3)
    sub v1, (mul v2, v3)

Reviewed By: david-arm

Differential Revision: https://reviews.llvm.org/D147236
1 parent 0e37487 commit 6a8d8f3
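For illustration, a minimal IR sketch of the add form of the pattern (the function name is hypothetical; the same case is exercised by the <2 x i64> test added in this commit). NEON has no 64-bit vector multiply, so the mul is lowered with SVE's predicated mul; rewriting the following add into a scalable add lets the pair be selected as a single predicated mla:

define <2 x i64> @mul_add_sketch(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
  ; The 64-bit element multiply is lowered via SVE (no NEON equivalent).
  %mul = mul <2 x i64> %b, %c
  ; With this combine, the add is rewritten as a scalable add so that
  ; mul+add select together as mla z0.d, p0/m, z1.d, z2.d.
  %add = add <2 x i64> %a, %mul
  ret <2 x i64> %add
}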


3 files changed: +171 -58 lines changed


llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 55 additions & 0 deletions
@@ -17804,9 +17804,64 @@ static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
   return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
 }
 
+// This works on the patterns of:
+//   add v1, (mul v2, v3)
+//   sub v1, (mul v2, v3)
+// for vectors of type <1 x i64> and <2 x i64> when SVE is available.
+// It will transform the add/sub to a scalable version, so that we can
+// make use of SVE's MLA/MLS that will be generated for that pattern.
+static SDValue performMulAddSubCombine(SDNode *N, SelectionDAG &DAG) {
+  // Before using SVE's features, check first if it's available.
+  if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
+    return SDValue();
+
+  if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
+    return SDValue();
+
+  if (!N->getValueType(0).isFixedLengthVector())
+    return SDValue();
+
+  SDValue MulValue, Op, ExtractIndexValue, ExtractOp;
+
+  if (N->getOperand(0)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    ExtractOp = N->getOperand(0);
+    Op = N->getOperand(1);
+  } else if (N->getOperand(1)->getOpcode() == ISD::EXTRACT_SUBVECTOR) {
+    ExtractOp = N->getOperand(1);
+    Op = N->getOperand(0);
+  } else
+    return SDValue();
+
+  MulValue = ExtractOp.getOperand(0);
+  ExtractIndexValue = ExtractOp.getOperand(1);
+
+  if (!ExtractOp.hasOneUse() && !MulValue.hasOneUse())
+    return SDValue();
+
+  // If the opcode is NOT MUL_PRED, then that is NOT the expected pattern:
+  if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
+    return SDValue();
+
+  // If the mul value type is NOT a scalable vector, then that is NOT the
+  // expected pattern:
+  EVT VT = MulValue.getValueType();
+  if (!VT.isScalableVector())
+    return SDValue();
+
+  // If the extract index is NOT 0, then that is NOT the expected pattern:
+  if (!cast<ConstantSDNode>(ExtractIndexValue)->isZero())
+    return SDValue();
+
+  SDValue ScaledOp = convertToScalableVector(DAG, VT, Op);
+  SDValue NewValue = DAG.getNode(N->getOpcode(), SDLoc(N), VT, {ScaledOp, MulValue});
+  return convertFromScalableVector(DAG, N->getValueType(0), NewValue);
+}
+
 static SDValue performAddSubCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     SelectionDAG &DAG) {
+  if (SDValue Val = performMulAddSubCombine(N, DAG))
+    return Val;
   // Try to change sum of two reductions.
   if (SDValue Val = performAddUADDVCombine(N, DAG))
     return Val;
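At the DAG level, the combine above matches an ISD::ADD/ISD::SUB whose multiply operand is an EXTRACT_SUBVECTOR at index 0 of a scalable AArch64ISD::MUL_PRED node, and rewrites the add/sub itself into its scalable form so that instruction selection can fold the pair into MLA/MLS. A minimal IR sketch of the subtract form (hypothetical function name; the equivalent case is covered by the new test below, which expects mls):

define <2 x i64> @mul_sub_sketch(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
  ; The 64-bit element multiply is again lowered via SVE's predicated mul.
  %mul = mul <2 x i64> %b, %c
  ; The combine turns this sub into a scalable sub of the MUL_PRED result,
  ; which selects as mls z0.d, p0/m, z1.d, z2.d.
  %sub = sub <2 x i64> %a, %mul
  ret <2 x i64> %sub
}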
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mattr=+sve | FileCheck %s
+
+define <2 x i64> @test_mul_add_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_add_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %add = add <2 x i64> %a, %mul
+  ret <2 x i64> %add
+}
+
+define <1 x i64> @test_mul_add_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_add_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mla z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %add = add <1 x i64> %mul, %a
+  ret <1 x i64> %add
+}
+
+define <2 x i64> @test_mul_sub_2x64(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_2x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q2 killed $q2 def $z2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <2 x i64> %b, %c
+  %sub = sub <2 x i64> %a, %mul
+  ret <2 x i64> %sub
+}
+
+define <1 x i64> @test_mul_sub_1x64(<1 x i64> %a, <1 x i64> %b, <1 x i64> %c) {
+; CHECK-LABEL: test_mul_sub_1x64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d2 killed $d2 def $z2
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
+; CHECK-NEXT:    mls z0.d, p0/m, z1.d, z2.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
+; CHECK-NEXT:    ret
+  %mul = mul <1 x i64> %b, %c
+  %sub = sub <1 x i64> %mul, %a
+  ret <1 x i64> %sub
+}

llvm/test/CodeGen/AArch64/sve-fixed-length-int-rem.ll

Lines changed: 54 additions & 58 deletions
@@ -606,7 +606,7 @@ define void @srem_v16i32(ptr %a, ptr %b) #0 {
 ;
 ; VBITS_GE_256-LABEL: srem_v16i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #8
+; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
 ; VBITS_GE_256-NEXT:    ld1w { z1.s }, p0/z, [x0]
@@ -680,13 +680,13 @@ define void @srem_v64i32(ptr %a, ptr %b) vscale_range(16,0) #0 {
 define <1 x i64> @srem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: srem_v1i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl1
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub d0, d0, d1
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %res = srem <1 x i64> %op1, %op2
   ret <1 x i64> %res
@@ -697,13 +697,13 @@ define <1 x i64> @srem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #
 define <2 x i64> @srem_v2i64(<2 x i64> %op1, <2 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: srem_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl2
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    sdiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = srem <2 x i64> %op1, %op2
   ret <2 x i64> %res
@@ -730,34 +730,32 @@ define void @srem_v4i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @srem_v8i64(ptr %a, ptr %b) #0 {
 ; VBITS_GE_128-LABEL: srem_v8i64:
 ; VBITS_GE_128:       // %bb.0:
-; VBITS_GE_128-NEXT:    ldp q4, q5, [x1]
-; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
-; VBITS_GE_128-NEXT:    ldp q7, q6, [x1, #32]
 ; VBITS_GE_128-NEXT:    ldp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    ldp q2, q3, [x0]
-; VBITS_GE_128-NEXT:    movprfx z16, z3
-; VBITS_GE_128-NEXT:    sdiv z16.d, p0/m, z16.d, z5.d
-; VBITS_GE_128-NEXT:    movprfx z17, z2
-; VBITS_GE_128-NEXT:    sdiv z17.d, p0/m, z17.d, z4.d
-; VBITS_GE_128-NEXT:    mul z5.d, p0/m, z5.d, z16.d
+; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
+; VBITS_GE_128-NEXT:    ldp q2, q3, [x1, #32]
 ; VBITS_GE_128-NEXT:    movprfx z16, z1
+; VBITS_GE_128-NEXT:    sdiv z16.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    movprfx z3, z0
+; VBITS_GE_128-NEXT:    sdiv z3.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    ldp q4, q5, [x0]
+; VBITS_GE_128-NEXT:    ldp q7, q6, [x1]
+; VBITS_GE_128-NEXT:    movprfx z16, z5
 ; VBITS_GE_128-NEXT:    sdiv z16.d, p0/m, z16.d, z6.d
-; VBITS_GE_128-NEXT:    mul z4.d, p0/m, z4.d, z17.d
-; VBITS_GE_128-NEXT:    movprfx z17, z0
-; VBITS_GE_128-NEXT:    sdiv z17.d, p0/m, z17.d, z7.d
-; VBITS_GE_128-NEXT:    mul z6.d, p0/m, z6.d, z16.d
-; VBITS_GE_128-NEXT:    mul z7.d, p0/m, z7.d, z17.d
-; VBITS_GE_128-NEXT:    sub v0.2d, v0.2d, v7.2d
-; VBITS_GE_128-NEXT:    sub v1.2d, v1.2d, v6.2d
-; VBITS_GE_128-NEXT:    sub v2.2d, v2.2d, v4.2d
+; VBITS_GE_128-NEXT:    movprfx z2, z4
+; VBITS_GE_128-NEXT:    sdiv z2.d, p0/m, z2.d, z7.d
 ; VBITS_GE_128-NEXT:    stp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    sub v0.2d, v3.2d, v5.2d
-; VBITS_GE_128-NEXT:    stp q2, q0, [x0]
+; VBITS_GE_128-NEXT:    movprfx z0, z4
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z2.d, z7.d
+; VBITS_GE_128-NEXT:    movprfx z1, z5
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z6.d
+; VBITS_GE_128-NEXT:    stp q0, q1, [x0]
 ; VBITS_GE_128-NEXT:    ret
 ;
 ; VBITS_GE_256-LABEL: srem_v8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #4
+; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
 ; VBITS_GE_256-NEXT:    ld1d { z1.d }, p0/z, [x0]
@@ -1426,7 +1424,7 @@ define void @urem_v16i32(ptr %a, ptr %b) #0 {
 ;
 ; VBITS_GE_256-LABEL: urem_v16i32:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #8
+; VBITS_GE_256-NEXT:    mov x8, #8 // =0x8
 ; VBITS_GE_256-NEXT:    ptrue p0.s, vl8
 ; VBITS_GE_256-NEXT:    ld1w { z0.s }, p0/z, [x0, x8, lsl #2]
 ; VBITS_GE_256-NEXT:    ld1w { z1.s }, p0/z, [x0]
@@ -1500,13 +1498,13 @@ define void @urem_v64i32(ptr %a, ptr %b) vscale_range(16,0) #0 {
 define <1 x i64> @urem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: urem_v1i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl1
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl1
+; CHECK-NEXT:    // kill: def $d1 killed $d1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub d0, d0, d1
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    ret
   %res = urem <1 x i64> %op1, %op2
   ret <1 x i64> %res
@@ -1517,13 +1515,13 @@ define <1 x i64> @urem_v1i64(<1 x i64> %op1, <1 x i64> %op2) vscale_range(1,0) #
 define <2 x i64> @urem_v2i64(<2 x i64> %op1, <2 x i64> %op2) vscale_range(1,0) #0 {
 ; CHECK-LABEL: urem_v2i64:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
-; CHECK-NEXT:    ptrue p0.d, vl2
 ; CHECK-NEXT:    // kill: def $q0 killed $q0 def $z0
+; CHECK-NEXT:    ptrue p0.d, vl2
+; CHECK-NEXT:    // kill: def $q1 killed $q1 def $z1
 ; CHECK-NEXT:    movprfx z2, z0
 ; CHECK-NEXT:    udiv z2.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    mul z1.d, p0/m, z1.d, z2.d
-; CHECK-NEXT:    sub v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    mls z0.d, p0/m, z2.d, z1.d
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
 ; CHECK-NEXT:    ret
   %res = urem <2 x i64> %op1, %op2
   ret <2 x i64> %res
@@ -1550,34 +1548,32 @@ define void @urem_v4i64(ptr %a, ptr %b) vscale_range(2,0) #0 {
 define void @urem_v8i64(ptr %a, ptr %b) #0 {
 ; VBITS_GE_128-LABEL: urem_v8i64:
 ; VBITS_GE_128:       // %bb.0:
-; VBITS_GE_128-NEXT:    ldp q4, q5, [x1]
-; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
-; VBITS_GE_128-NEXT:    ldp q7, q6, [x1, #32]
 ; VBITS_GE_128-NEXT:    ldp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    ldp q2, q3, [x0]
-; VBITS_GE_128-NEXT:    movprfx z16, z3
-; VBITS_GE_128-NEXT:    udiv z16.d, p0/m, z16.d, z5.d
-; VBITS_GE_128-NEXT:    movprfx z17, z2
-; VBITS_GE_128-NEXT:    udiv z17.d, p0/m, z17.d, z4.d
-; VBITS_GE_128-NEXT:    mul z5.d, p0/m, z5.d, z16.d
+; VBITS_GE_128-NEXT:    ptrue p0.d, vl2
+; VBITS_GE_128-NEXT:    ldp q2, q3, [x1, #32]
 ; VBITS_GE_128-NEXT:    movprfx z16, z1
+; VBITS_GE_128-NEXT:    udiv z16.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z3.d
+; VBITS_GE_128-NEXT:    movprfx z3, z0
+; VBITS_GE_128-NEXT:    udiv z3.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z3.d, z2.d
+; VBITS_GE_128-NEXT:    ldp q4, q5, [x0]
+; VBITS_GE_128-NEXT:    ldp q7, q6, [x1]
+; VBITS_GE_128-NEXT:    movprfx z16, z5
 ; VBITS_GE_128-NEXT:    udiv z16.d, p0/m, z16.d, z6.d
-; VBITS_GE_128-NEXT:    mul z4.d, p0/m, z4.d, z17.d
-; VBITS_GE_128-NEXT:    movprfx z17, z0
-; VBITS_GE_128-NEXT:    udiv z17.d, p0/m, z17.d, z7.d
-; VBITS_GE_128-NEXT:    mul z6.d, p0/m, z6.d, z16.d
-; VBITS_GE_128-NEXT:    mul z7.d, p0/m, z7.d, z17.d
-; VBITS_GE_128-NEXT:    sub v0.2d, v0.2d, v7.2d
-; VBITS_GE_128-NEXT:    sub v1.2d, v1.2d, v6.2d
-; VBITS_GE_128-NEXT:    sub v2.2d, v2.2d, v4.2d
+; VBITS_GE_128-NEXT:    movprfx z2, z4
+; VBITS_GE_128-NEXT:    udiv z2.d, p0/m, z2.d, z7.d
 ; VBITS_GE_128-NEXT:    stp q0, q1, [x0, #32]
-; VBITS_GE_128-NEXT:    sub v0.2d, v3.2d, v5.2d
-; VBITS_GE_128-NEXT:    stp q2, q0, [x0]
+; VBITS_GE_128-NEXT:    movprfx z0, z4
+; VBITS_GE_128-NEXT:    mls z0.d, p0/m, z2.d, z7.d
+; VBITS_GE_128-NEXT:    movprfx z1, z5
+; VBITS_GE_128-NEXT:    mls z1.d, p0/m, z16.d, z6.d
+; VBITS_GE_128-NEXT:    stp q0, q1, [x0]
 ; VBITS_GE_128-NEXT:    ret
 ;
 ; VBITS_GE_256-LABEL: urem_v8i64:
 ; VBITS_GE_256:       // %bb.0:
-; VBITS_GE_256-NEXT:    mov x8, #4
+; VBITS_GE_256-NEXT:    mov x8, #4 // =0x4
 ; VBITS_GE_256-NEXT:    ptrue p0.d, vl4
 ; VBITS_GE_256-NEXT:    ld1d { z0.d }, p0/z, [x0, x8, lsl #3]
 ; VBITS_GE_256-NEXT:    ld1d { z1.d }, p0/z, [x0]
