Skip to content

Commit 3b98335

Browse files
committed
[AArch64] Alter mull buildvectors(ext(..)) combine to work on shuffles
D120018 altered this combine to work on buildvectors as opposed to shuffle dup's. This works well for dups and other things that are expanded into buildvectors. Some shuffles are legal though, and stay as vector_shuffle through lowering. This expands the transform to also handle shuffles, so that we can turn mul(shuffle(sext into mul(sext(shuffle and more readily make smull/umull instructions. This can come up from the SLP vectorizer adding shuffles that are costed from extends. Differential Revision: https://reviews.llvm.org/D123012
1 parent a70480d commit 3b98335

File tree

2 files changed

+49
-57
lines changed

2 files changed

+49
-57
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13629,15 +13629,17 @@ static EVT calculatePreExtendType(SDValue Extend) {
1362913629
}
1363013630
}
1363113631

13632-
/// Combines a buildvector(sext/zext) node pattern into sext/zext(buildvector)
13633-
/// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
13634-
static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
13632+
/// Combines a buildvector(sext/zext) or shuffle(sext/zext, undef) node pattern
13633+
/// into sext/zext(buildvector) or sext/zext(shuffle) making use of the vector
13634+
/// SExt/ZExt rather than the scalar SExt/ZExt
13635+
static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
1363513636
EVT VT = BV.getValueType();
13636-
if (BV.getOpcode() != ISD::BUILD_VECTOR)
13637+
if (BV.getOpcode() != ISD::BUILD_VECTOR &&
13638+
BV.getOpcode() != ISD::VECTOR_SHUFFLE)
1363713639
return SDValue();
1363813640

13639-
// Use the first item in the buildvector to get the size of the extend, and
13640-
// make sure it looks valid.
13641+
// Use the first item in the buildvector/shuffle to get the size of the
13642+
// extend, and make sure it looks valid.
1364113643
SDValue Extend = BV->getOperand(0);
1364213644
unsigned ExtendOpcode = Extend.getOpcode();
1364313645
bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
@@ -13646,31 +13648,49 @@ static SDValue performBuildVectorExtendCombine(SDValue BV, SelectionDAG &DAG) {
1364613648
if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
1364713649
ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
1364813650
return SDValue();
13651+
// Shuffle inputs are vector, limit to SIGN_EXTEND and ZERO_EXTEND to ensure
13652+
// calculatePreExtendType will work without issue.
13653+
if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
13654+
ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
13655+
return SDValue();
1364913656

1365013657
// Restrict valid pre-extend data type
1365113658
EVT PreExtendType = calculatePreExtendType(Extend);
1365213659
if (PreExtendType == MVT::Other ||
13653-
PreExtendType.getSizeInBits() != VT.getScalarSizeInBits() / 2)
13660+
PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
1365413661
return SDValue();
1365513662

1365613663
// Make sure all other operands are equally extended
1365713664
for (SDValue Op : drop_begin(BV->ops())) {
13665+
if (Op.isUndef())
13666+
continue;
1365813667
unsigned Opc = Op.getOpcode();
1365913668
bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
1366013669
Opc == ISD::AssertSext;
1366113670
if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
1366213671
return SDValue();
1366313672
}
1366413673

13665-
EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
13666-
EVT PreExtendLegalType =
13667-
PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
13674+
SDValue NBV;
1366813675
SDLoc DL(BV);
13669-
SmallVector<SDValue, 8> NewOps;
13670-
for (SDValue Op : BV->ops())
13671-
NewOps.push_back(
13672-
DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, PreExtendLegalType));
13673-
SDValue NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
13676+
if (BV.getOpcode() == ISD::BUILD_VECTOR) {
13677+
EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
13678+
EVT PreExtendLegalType =
13679+
PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
13680+
SmallVector<SDValue, 8> NewOps;
13681+
for (SDValue Op : BV->ops())
13682+
NewOps.push_back(Op.isUndef() ? DAG.getUNDEF(PreExtendLegalType)
13683+
: DAG.getAnyExtOrTrunc(Op.getOperand(0), DL,
13684+
PreExtendLegalType));
13685+
NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
13686+
} else { // BV.getOpcode() == ISD::VECTOR_SHUFFLE
13687+
EVT PreExtendVT = VT.changeVectorElementType(PreExtendType.getScalarType());
13688+
NBV = DAG.getVectorShuffle(PreExtendVT, DL, BV.getOperand(0).getOperand(0),
13689+
BV.getOperand(1).isUndef()
13690+
? DAG.getUNDEF(PreExtendVT)
13691+
: BV.getOperand(1).getOperand(0),
13692+
cast<ShuffleVectorSDNode>(BV)->getMask());
13693+
}
1367413694
return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
1367513695
}
1367613696

@@ -13682,8 +13702,8 @@ static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
1368213702
if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
1368313703
return SDValue();
1368413704

13685-
SDValue Op0 = performBuildVectorExtendCombine(Mul->getOperand(0), DAG);
13686-
SDValue Op1 = performBuildVectorExtendCombine(Mul->getOperand(1), DAG);
13705+
SDValue Op0 = performBuildShuffleExtendCombine(Mul->getOperand(0), DAG);
13706+
SDValue Op1 = performBuildShuffleExtendCombine(Mul->getOperand(1), DAG);
1368713707

1368813708
// Neither operands have been changed, don't make any further changes
1368913709
if (!Op0 && !Op1)

llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll

Lines changed: 12 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,8 @@ entry:
245245
define <8 x i16> @missing_insert(<8 x i8> %b) {
246246
; CHECK-LABEL: missing_insert:
247247
; CHECK: // %bb.0: // %entry
248-
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
249-
; CHECK-NEXT: ext v1.16b, v0.16b, v0.16b, #4
250-
; CHECK-NEXT: mul v0.8h, v1.8h, v0.8h
248+
; CHECK-NEXT: ext v1.8b, v0.8b, v0.8b, #2
249+
; CHECK-NEXT: smull v0.8h, v1.8b, v0.8b
251250
; CHECK-NEXT: ret
252251
entry:
253252
%ext.b = sext <8 x i8> %b to <8 x i16>
@@ -259,11 +258,8 @@ entry:
259258
define <8 x i16> @shufsext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
260259
; CHECK-LABEL: shufsext_v8i8_v8i16:
261260
; CHECK: // %bb.0: // %entry
262-
; CHECK-NEXT: sshll v0.8h, v0.8b, #0
263-
; CHECK-NEXT: sshll v1.8h, v1.8b, #0
264-
; CHECK-NEXT: rev64 v0.8h, v0.8h
265-
; CHECK-NEXT: ext v0.16b, v0.16b, v0.16b, #8
266-
; CHECK-NEXT: mul v0.8h, v0.8h, v1.8h
261+
; CHECK-NEXT: rev64 v0.8b, v0.8b
262+
; CHECK-NEXT: smull v0.8h, v0.8b, v1.8b
267263
; CHECK-NEXT: ret
268264
entry:
269265
%in = sext <8 x i8> %src to <8 x i16>
@@ -276,17 +272,8 @@ entry:
276272
define <2 x i64> @shufsext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
277273
; CHECK-LABEL: shufsext_v2i32_v2i64:
278274
; CHECK: // %bb.0: // %entry
279-
; CHECK-NEXT: sshll v0.2d, v0.2s, #0
280-
; CHECK-NEXT: sshll v1.2d, v1.2s, #0
281-
; CHECK-NEXT: ext v0.16b, v0.16b, v0.16b, #8
282-
; CHECK-NEXT: fmov x9, d1
283-
; CHECK-NEXT: mov x8, v1.d[1]
284-
; CHECK-NEXT: fmov x10, d0
285-
; CHECK-NEXT: mov x11, v0.d[1]
286-
; CHECK-NEXT: mul x9, x10, x9
287-
; CHECK-NEXT: mul x8, x11, x8
288-
; CHECK-NEXT: fmov d0, x9
289-
; CHECK-NEXT: mov v0.d[1], x8
275+
; CHECK-NEXT: rev64 v0.2s, v0.2s
276+
; CHECK-NEXT: smull v0.2d, v0.2s, v1.2s
290277
; CHECK-NEXT: ret
291278
entry:
292279
%in = sext <2 x i32> %src to <2 x i64>
@@ -299,11 +286,8 @@ entry:
299286
define <8 x i16> @shufzext_v8i8_v8i16(<8 x i8> %src, <8 x i8> %b) {
300287
; CHECK-LABEL: shufzext_v8i8_v8i16:
301288
; CHECK: // %bb.0: // %entry
302-
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
303-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
304-
; CHECK-NEXT: rev64 v0.8h, v0.8h
305-
; CHECK-NEXT: ext v0.16b, v0.16b, v0.16b, #8
306-
; CHECK-NEXT: mul v0.8h, v0.8h, v1.8h
289+
; CHECK-NEXT: rev64 v0.8b, v0.8b
290+
; CHECK-NEXT: umull v0.8h, v0.8b, v1.8b
307291
; CHECK-NEXT: ret
308292
entry:
309293
%in = zext <8 x i8> %src to <8 x i16>
@@ -316,17 +300,8 @@ entry:
316300
define <2 x i64> @shufzext_v2i32_v2i64(<2 x i32> %src, <2 x i32> %b) {
317301
; CHECK-LABEL: shufzext_v2i32_v2i64:
318302
; CHECK: // %bb.0: // %entry
319-
; CHECK-NEXT: sshll v0.2d, v0.2s, #0
320-
; CHECK-NEXT: sshll v1.2d, v1.2s, #0
321-
; CHECK-NEXT: ext v0.16b, v0.16b, v0.16b, #8
322-
; CHECK-NEXT: fmov x9, d1
323-
; CHECK-NEXT: mov x8, v1.d[1]
324-
; CHECK-NEXT: fmov x10, d0
325-
; CHECK-NEXT: mov x11, v0.d[1]
326-
; CHECK-NEXT: mul x9, x10, x9
327-
; CHECK-NEXT: mul x8, x11, x8
328-
; CHECK-NEXT: fmov d0, x9
329-
; CHECK-NEXT: mov v0.d[1], x8
303+
; CHECK-NEXT: rev64 v0.2s, v0.2s
304+
; CHECK-NEXT: smull v0.2d, v0.2s, v1.2s
330305
; CHECK-NEXT: ret
331306
entry:
332307
%in = sext <2 x i32> %src to <2 x i64>
@@ -339,11 +314,8 @@ entry:
339314
define <8 x i16> @shufzext_v8i8_v8i16_twoin(<8 x i8> %src1, <8 x i8> %src2, <8 x i8> %b) {
340315
; CHECK-LABEL: shufzext_v8i8_v8i16_twoin:
341316
; CHECK: // %bb.0: // %entry
342-
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
343-
; CHECK-NEXT: ushll v1.8h, v1.8b, #0
344-
; CHECK-NEXT: trn1 v0.8h, v0.8h, v1.8h
345-
; CHECK-NEXT: ushll v1.8h, v2.8b, #0
346-
; CHECK-NEXT: mul v0.8h, v0.8h, v1.8h
317+
; CHECK-NEXT: trn1 v0.8b, v0.8b, v1.8b
318+
; CHECK-NEXT: umull v0.8h, v0.8b, v2.8b
347319
; CHECK-NEXT: ret
348320
entry:
349321
%in1 = zext <8 x i8> %src1 to <8 x i16>

0 commit comments

Comments
 (0)