Skip to content
This repository was archived by the owner on Feb 5, 2019. It is now read-only.

Commit 298d1a6

Browse files
committed
[DAG] Teach DAG to also reassociate vector operations
This commit teaches DAG to reassociate vector ops, which in turn enables constant folding of vector op chains that appear later on during custom lowering and DAG combine.

Reviewed by Andrea Di Biagio.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@199135 91177308-0d34-0410-b5e6-96231b3b80d8
1 parent de0847d commit 298d1a6

File tree

4 files changed

+82
-27
lines changed

4 files changed

+82
-27
lines changed

include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1492,6 +1492,8 @@ class BuildVectorSDNode : public SDNode {
14921492
unsigned &SplatBitSize, bool &HasAnyUndefs,
14931493
unsigned MinSplatBits = 0, bool isBigEndian = false);
14941494

1495+
bool isConstant() const;
1496+
14951497
static inline bool classof(const SDNode *N) {
14961498
return N->getOpcode() == ISD::BUILD_VECTOR;
14971499
}

lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,51 @@ static bool isOneUseSetCC(SDValue N) {
610610
SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,
611611
SDValue N0, SDValue N1) {
612612
EVT VT = N0.getValueType();
613+
if (VT.isVector()) {
614+
if (N0.getOpcode() == Opc) {
615+
BuildVectorSDNode *L = dyn_cast<BuildVectorSDNode>(N0.getOperand(1));
616+
if(L && L->isConstant()) {
617+
BuildVectorSDNode *R = dyn_cast<BuildVectorSDNode>(N1);
618+
if (R && R->isConstant()) {
619+
// reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2))
620+
SDValue OpNode = DAG.FoldConstantArithmetic(Opc, VT, L, R);
621+
return DAG.getNode(Opc, DL, VT, N0.getOperand(0), OpNode);
622+
}
623+
624+
if (N0.hasOneUse()) {
625+
// reassoc. (op (op x, c1), y) -> (op (op x, y), c1) iff x+c1 has one
626+
// use
627+
SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT,
628+
N0.getOperand(0), N1);
629+
AddToWorkList(OpNode.getNode());
630+
return DAG.getNode(Opc, DL, VT, OpNode, N0.getOperand(1));
631+
}
632+
}
633+
}
634+
635+
if (N1.getOpcode() == Opc) {
636+
BuildVectorSDNode *R = dyn_cast<BuildVectorSDNode>(N1.getOperand(1));
637+
if (R && R->isConstant()) {
638+
BuildVectorSDNode *L = dyn_cast<BuildVectorSDNode>(N0);
639+
if (L && L->isConstant()) {
640+
// reassoc. (op c2, (op x, c1)) -> (op x, (op c1, c2))
641+
SDValue OpNode = DAG.FoldConstantArithmetic(Opc, VT, R, L);
642+
return DAG.getNode(Opc, DL, VT, N1.getOperand(0), OpNode);
643+
}
644+
if (N1.hasOneUse()) {
645+
// reassoc. (op y, (op x, c1)) -> (op (op x, y), c1) iff x+c1 has one
646+
// use
647+
SDValue OpNode = DAG.getNode(Opc, SDLoc(N0), VT,
648+
N1.getOperand(0), N0);
649+
AddToWorkList(OpNode.getNode());
650+
return DAG.getNode(Opc, DL, VT, OpNode, N1.getOperand(1));
651+
}
652+
}
653+
}
654+
655+
return SDValue();
656+
}
657+
613658
if (N0.getOpcode() == Opc && isa<ConstantSDNode>(N0.getOperand(1))) {
614659
if (isa<ConstantSDNode>(N1)) {
615660
// reassoc. (op (op x, c1), c2) -> (op x, (op c1, c2))
@@ -5868,14 +5913,7 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
58685913
if (!LegalTypes &&
58695914
N0.getOpcode() == ISD::BUILD_VECTOR && N0.getNode()->hasOneUse() &&
58705915
VT.isVector()) {
5871-
bool isSimple = true;
5872-
for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i)
5873-
if (N0.getOperand(i).getOpcode() != ISD::UNDEF &&
5874-
N0.getOperand(i).getOpcode() != ISD::Constant &&
5875-
N0.getOperand(i).getOpcode() != ISD::ConstantFP) {
5876-
isSimple = false;
5877-
break;
5878-
}
5916+
bool isSimple = cast<BuildVectorSDNode>(N0)->isConstant();
58795917

58805918
EVT DestEltVT = N->getValueType(0).getVectorElementType();
58815919
assert(!DestEltVT.isVector() &&
@@ -10381,18 +10419,15 @@ SDValue DAGCombiner::SimplifyVBinOp(SDNode *N) {
1038110419
// this operation.
1038210420
if (LHS.getOpcode() == ISD::BUILD_VECTOR &&
1038310421
RHS.getOpcode() == ISD::BUILD_VECTOR) {
10422+
// Check if both vectors are constants. If not bail out.
10423+
if (!cast<BuildVectorSDNode>(LHS)->isConstant() &&
10424+
!cast<BuildVectorSDNode>(RHS)->isConstant())
10425+
return SDValue();
10426+
1038410427
SmallVector<SDValue, 8> Ops;
1038510428
for (unsigned i = 0, e = LHS.getNumOperands(); i != e; ++i) {
1038610429
SDValue LHSOp = LHS.getOperand(i);
1038710430
SDValue RHSOp = RHS.getOperand(i);
10388-
// If these two elements can't be folded, bail out.
10389-
if ((LHSOp.getOpcode() != ISD::UNDEF &&
10390-
LHSOp.getOpcode() != ISD::Constant &&
10391-
LHSOp.getOpcode() != ISD::ConstantFP) ||
10392-
(RHSOp.getOpcode() != ISD::UNDEF &&
10393-
RHSOp.getOpcode() != ISD::Constant &&
10394-
RHSOp.getOpcode() != ISD::ConstantFP))
10395-
break;
1039610431

1039710432
// Can't fold divide by zero.
1039810433
if (N->getOpcode() == ISD::SDIV || N->getOpcode() == ISD::UDIV ||

lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6533,6 +6533,15 @@ bool BuildVectorSDNode::isConstantSplat(APInt &SplatValue,
65336533
return true;
65346534
}
65356535

6536+
/// Return true if every operand of this BUILD_VECTOR node is an ISD::UNDEF,
/// ISD::Constant or ISD::ConstantFP node, and false as soon as any other
/// operand opcode is found. A node with no operands returns true, because the
/// loop never encounters a disqualifying operand.
bool BuildVectorSDNode::isConstant() const {
6537+
for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
6538+
unsigned Opc = getOperand(i).getOpcode();
6539+
if (Opc != ISD::UNDEF && Opc != ISD::Constant && Opc != ISD::ConstantFP)
6540+
return false;
6541+
}
6542+
return true;
6543+
}
6544+
65366545
bool ShuffleVectorSDNode::isSplatMask(const int *Mask, EVT VT) {
65376546
// Find the first non-undef value in the shuffle mask.
65386547
unsigned i, e;

test/CodeGen/X86/vector-gep.ll

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,26 @@
44
;CHECK-LABEL: AGEP0:
55
define <4 x i32*> @AGEP0(i32* %ptr) nounwind {
66
entry:
7+
;CHECK-LABEL: AGEP0
8+
;CHECK: vbroadcast
9+
;CHECK-NEXT: vpaddd
10+
;CHECK-NEXT: ret
711
%vecinit.i = insertelement <4 x i32*> undef, i32* %ptr, i32 0
812
%vecinit2.i = insertelement <4 x i32*> %vecinit.i, i32* %ptr, i32 1
913
%vecinit4.i = insertelement <4 x i32*> %vecinit2.i, i32* %ptr, i32 2
1014
%vecinit6.i = insertelement <4 x i32*> %vecinit4.i, i32* %ptr, i32 3
11-
;CHECK: padd
1215
%A2 = getelementptr <4 x i32*> %vecinit6.i, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
13-
;CHECK: padd
1416
%A3 = getelementptr <4 x i32*> %A2, <4 x i32> <i32 10, i32 14, i32 19, i32 233>
1517
ret <4 x i32*> %A3
16-
;CHECK: ret
1718
}
1819

1920
;CHECK-LABEL: AGEP1:
2021
define i32 @AGEP1(<4 x i32*> %param) nounwind {
2122
entry:
22-
;CHECK: padd
23+
;CHECK-LABEL: AGEP1
24+
;CHECK: vpaddd
25+
;CHECK-NEXT: vpextrd
26+
;CHECK-NEXT: movl
2327
%A2 = getelementptr <4 x i32*> %param, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
2428
%k = extractelement <4 x i32*> %A2, i32 3
2529
%v = load i32* %k
@@ -30,8 +34,9 @@ entry:
3034
;CHECK-LABEL: AGEP2:
3135
define i32 @AGEP2(<4 x i32*> %param, <4 x i32> %off) nounwind {
3236
entry:
33-
;CHECK: pslld $2
34-
;CHECK: padd
37+
;CHECK-LABEL: AGEP2
38+
;CHECK: vpslld $2
39+
;CHECK-NEXT: vpadd
3540
%A2 = getelementptr <4 x i32*> %param, <4 x i32> %off
3641
%k = extractelement <4 x i32*> %A2, i32 3
3742
%v = load i32* %k
@@ -42,8 +47,9 @@ entry:
4247
;CHECK-LABEL: AGEP3:
4348
define <4 x i32*> @AGEP3(<4 x i32*> %param, <4 x i32> %off) nounwind {
4449
entry:
45-
;CHECK: pslld $2
46-
;CHECK: padd
50+
;CHECK-LABEL: AGEP3
51+
;CHECK: vpslld $2
52+
;CHECK-NEXT: vpadd
4753
%A2 = getelementptr <4 x i32*> %param, <4 x i32> %off
4854
%v = alloca i32
4955
%k = insertelement <4 x i32*> %A2, i32* %v, i32 3
@@ -54,10 +60,11 @@ entry:
5460
;CHECK-LABEL: AGEP4:
5561
define <4 x i16*> @AGEP4(<4 x i16*> %param, <4 x i32> %off) nounwind {
5662
entry:
63+
;CHECK-LABEL: AGEP4
5764
; Multiply offset by two (add it to itself).
58-
;CHECK: padd
65+
;CHECK: vpadd
5966
; add the base to the offset
60-
;CHECK: padd
67+
;CHECK-NEXT: vpadd
6168
%A = getelementptr <4 x i16*> %param, <4 x i32> %off
6269
ret <4 x i16*> %A
6370
;CHECK: ret
@@ -66,7 +73,8 @@ entry:
6673
;CHECK-LABEL: AGEP5:
6774
define <4 x i8*> @AGEP5(<4 x i8*> %param, <4 x i8> %off) nounwind {
6875
entry:
69-
;CHECK: paddd
76+
;CHECK-LABEL: AGEP5
77+
;CHECK: vpaddd
7078
%A = getelementptr <4 x i8*> %param, <4 x i8> %off
7179
ret <4 x i8*> %A
7280
;CHECK: ret
@@ -77,6 +85,7 @@ entry:
7785
;CHECK-LABEL: AGEP6:
7886
define <4 x i8*> @AGEP6(<4 x i8*> %param, <4 x i32> %off) nounwind {
7987
entry:
88+
;CHECK-LABEL: AGEP6
8089
;CHECK-NOT: pslld
8190
%A = getelementptr <4 x i8*> %param, <4 x i32> %off
8291
ret <4 x i8*> %A

0 commit comments

Comments
 (0)