Skip to content

Commit 69a2115

Browse files
authored
[DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) -> extract_elt(bitcast(vec),c3) (#107987)
Extends existing trunc(extract_elt(vec,c1)) -> extract_elt(bitcast(vec),c3) fold. Noticed while working on #107404
1 parent 387bee9 commit 69a2115

File tree

2 files changed

+41
-29
lines changed

2 files changed

+41
-29
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15142,26 +15142,42 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1514215142
// Note: We only run this optimization after type legalization (which often
1514315143
// creates this pattern) and before operation legalization after which
1514415144
// we need to be more careful about the vector instructions that we generate.
15145-
if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
15146-
LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
15147-
EVT VecTy = N0.getOperand(0).getValueType();
15148-
EVT ExTy = N0.getValueType();
15145+
if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
15146+
N0->hasOneUse()) {
1514915147
EVT TrTy = N->getValueType(0);
15148+
SDValue Src = N0;
15149+
15150+
// Check for cases where we shift down an upper element before truncation.
15151+
int EltOffset = 0;
15152+
if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
15153+
if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
15154+
if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
15155+
Src = Src.getOperand(0);
15156+
EltOffset = *ShAmt / TrTy.getSizeInBits();
15157+
}
15158+
}
15159+
}
1515015160

15151-
auto EltCnt = VecTy.getVectorElementCount();
15152-
unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
15153-
auto NewEltCnt = EltCnt * SizeRatio;
15161+
if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
15162+
EVT VecTy = Src.getOperand(0).getValueType();
15163+
EVT ExTy = Src.getValueType();
1515415164

15155-
EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
15156-
assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
15165+
auto EltCnt = VecTy.getVectorElementCount();
15166+
unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
15167+
auto NewEltCnt = EltCnt * SizeRatio;
1515715168

15158-
SDValue EltNo = N0->getOperand(1);
15159-
if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
15160-
int Elt = EltNo->getAsZExtVal();
15161-
int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
15162-
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
15163-
DAG.getBitcast(NVT, N0.getOperand(0)),
15164-
DAG.getVectorIdxConstant(Index, DL));
15169+
EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
15170+
assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
15171+
15172+
SDValue EltNo = Src->getOperand(1);
15173+
if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
15174+
int Elt = EltNo->getAsZExtVal();
15175+
int Index = isLE ? (Elt * SizeRatio + EltOffset)
15176+
: (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
15177+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
15178+
DAG.getBitcast(NVT, Src.getOperand(0)),
15179+
DAG.getVectorIdxConstant(Index, DL));
15180+
}
1516515181
}
1516615182
}
1516715183

llvm/test/CodeGen/AArch64/expand-select.ll

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -33,24 +33,20 @@ define void @bar(i32 %In1, <2 x i96> %In2, <2 x i96> %In3, ptr %Out) {
3333
; CHECK: // %bb.0:
3434
; CHECK-NEXT: and w8, w0, #0x1
3535
; CHECK-NEXT: fmov s0, wzr
36-
; CHECK-NEXT: ldr x11, [sp, #16]
36+
; CHECK-NEXT: ldr x10, [sp, #16]
3737
; CHECK-NEXT: fmov s1, w8
38-
; CHECK-NEXT: ldp x9, x10, [sp]
3938
; CHECK-NEXT: cmeq v0.4s, v1.4s, v0.4s
40-
; CHECK-NEXT: dup v1.4s, v0.s[0]
41-
; CHECK-NEXT: mov x8, v1.d[1]
42-
; CHECK-NEXT: lsr x8, x8, #32
43-
; CHECK-NEXT: tst w8, #0x1
4439
; CHECK-NEXT: fmov w8, s0
45-
; CHECK-NEXT: csel x10, x5, x10, ne
46-
; CHECK-NEXT: csel x9, x4, x9, ne
47-
; CHECK-NEXT: stur x9, [x11, #12]
4840
; CHECK-NEXT: tst w8, #0x1
49-
; CHECK-NEXT: str w10, [x11, #20]
50-
; CHECK-NEXT: csel x8, x2, x6, ne
41+
; CHECK-NEXT: ldp x9, x8, [sp]
42+
; CHECK-NEXT: csel x11, x2, x6, ne
43+
; CHECK-NEXT: str x11, [x10]
44+
; CHECK-NEXT: csel x9, x4, x9, ne
45+
; CHECK-NEXT: csel x8, x5, x8, ne
46+
; CHECK-NEXT: stur x9, [x10, #12]
5147
; CHECK-NEXT: csel x9, x3, x7, ne
52-
; CHECK-NEXT: str x8, [x11]
53-
; CHECK-NEXT: str w9, [x11, #8]
48+
; CHECK-NEXT: str w8, [x10, #20]
49+
; CHECK-NEXT: str w9, [x10, #8]
5450
; CHECK-NEXT: ret
5551
%cond = and i32 %In1, 1
5652
%cbool = icmp eq i32 %cond, 0

0 commit comments

Comments
 (0)