Skip to content

Commit 189179b

Browse files
committed
[AArch64] Improve lowering of truncating build vectors
1. Look through assert_zext/sext nodes. 2. Generalize `ReconstructTruncateFromBuildVector` to work for more cases. Change-Id: I717a7471986ea4961c71df62912f8dd6f1723118
1 parent 9150858 commit 189179b

File tree

10 files changed

+778
-963
lines changed

10 files changed

+778
-963
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 119 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -11369,54 +11369,105 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
1136911369
return true;
1137011370
}
1137111371

11372-
// Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
11373-
// v4i32s. This is really a truncate, which we can construct out of (legal)
11374-
// concats and truncate nodes.
11375-
static SDValue ReconstructTruncateFromBuildVector(SDValue V, SelectionDAG &DAG) {
11376-
if (V.getValueType() != MVT::v16i8)
11377-
return SDValue();
11378-
assert(V.getNumOperands() == 16 && "Expected 16 operands on the BUILDVECTOR");
11379-
11380-
for (unsigned X = 0; X < 4; X++) {
11381-
// Check the first item in each group is an extract from lane 0 of a v4i32
11382-
// or v4i16.
11383-
SDValue BaseExt = V.getOperand(X * 4);
11384-
if (BaseExt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
11385-
(BaseExt.getOperand(0).getValueType() != MVT::v4i16 &&
11386-
BaseExt.getOperand(0).getValueType() != MVT::v4i32) ||
11387-
!isa<ConstantSDNode>(BaseExt.getOperand(1)) ||
11388-
BaseExt.getConstantOperandVal(1) != 0)
11372+
// Detect patterns like a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3, that
11373+
// are truncates, which we can construct out of (legal) concats and truncate
11374+
// nodes.
11375+
static SDValue ReconstructTruncateFromBuildVector(SDValue V,
11376+
SelectionDAG &DAG) {
11377+
EVT BVTy = V.getValueType();
11378+
if (BVTy != MVT::v16i8 && BVTy != MVT::v8i16 && BVTy != MVT::v8i8 &&
11379+
BVTy != MVT::v4i16)
11380+
return SDValue();
11381+
11382+
// Only handle truncating BVs.
11383+
if (V.getOperand(0).getValueType().getSizeInBits() ==
11384+
BVTy.getScalarSizeInBits())
11385+
return SDValue();
11386+
11387+
SmallVector<SDValue, 4> Sources;
11388+
uint64_t LastIdx = 0;
11389+
uint64_t MaxIdx = 0;
11390+
// Check for sequential indices that restart at 0 for each source, e.g. 0, 1, ..., 0, 1, ...
11391+
for (SDValue Extr : V->ops()) {
11392+
SDValue SourceVec = Extr.getOperand(0);
11393+
EVT SourceVecTy = SourceVec.getValueType();
11394+
11395+
if (!DAG.getTargetLoweringInfo().isTypeLegal(SourceVecTy))
1138911396
return SDValue();
11390-
SDValue Base = BaseExt.getOperand(0);
11391-
// And check the other items are extracts from the same vector.
11392-
for (unsigned Y = 1; Y < 4; Y++) {
11393-
SDValue Ext = V.getOperand(X * 4 + Y);
11394-
if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
11395-
Ext.getOperand(0) != Base ||
11396-
!isa<ConstantSDNode>(Ext.getOperand(1)) ||
11397-
Ext.getConstantOperandVal(1) != Y)
11397+
if (!isa<ConstantSDNode>(Extr.getOperand(1)))
11398+
return SDValue();
11399+
11400+
uint64_t CurIdx = Extr.getConstantOperandVal(1);
11401+
// Allow repeat of sources.
11402+
if (CurIdx == 0) {
11403+
// Check if all lanes are used by the BV.
11404+
if (Sources.size() && Sources[Sources.size() - 1]
11405+
.getValueType()
11406+
.getVectorMinNumElements() != LastIdx + 1)
1139811407
return SDValue();
11399-
}
11408+
Sources.push_back(SourceVec);
11409+
} else if (CurIdx != LastIdx + 1)
11410+
return SDValue();
11411+
11412+
LastIdx = CurIdx;
11413+
MaxIdx = std::max(MaxIdx, CurIdx);
1140011414
}
1140111415

11402-
// Turn the buildvector into a series of truncates and concates, which will
11403-
// become uzip1's. Any v4i32s we found get truncated to v4i16, which are
11404-
// concat together to produce 2 v8i16. These are both truncated and concat
11405-
// together.
11416+
// Check if all lanes are used by the BV.
11417+
if (Sources[Sources.size() - 1].getValueType().getVectorMinNumElements() !=
11418+
LastIdx + 1)
11419+
return SDValue();
11420+
if (Sources.size() % 2 != 0)
11421+
return SDValue();
11422+
11423+
// At this point we know that we have a truncating BV of extract_vector_elt.
11424+
// We can just truncate and concat them.
1140611425
SDLoc DL(V);
11407-
SDValue Trunc[4] = {
11408-
V.getOperand(0).getOperand(0), V.getOperand(4).getOperand(0),
11409-
V.getOperand(8).getOperand(0), V.getOperand(12).getOperand(0)};
11410-
for (SDValue &V : Trunc)
11411-
if (V.getValueType() == MVT::v4i32)
11412-
V = DAG.getNode(ISD::TRUNCATE, DL, MVT::v4i16, V);
11413-
SDValue Concat0 =
11414-
DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[0], Trunc[1]);
11415-
SDValue Concat1 =
11416-
DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[2], Trunc[3]);
11417-
SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat0);
11418-
SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat1);
11419-
return DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, Trunc0, Trunc1);
11426+
LLVMContext &Ctx = *DAG.getContext();
11427+
while (Sources.size() > 1) {
11428+
for (unsigned i = 0; i < Sources.size(); i += 2) {
11429+
SDValue V1 = Sources[i];
11430+
SDValue V2 = Sources[i + 1];
11431+
EVT VT1 = V1.getValueType();
11432+
EVT VT2 = V2.getValueType();
11433+
11434+
if (VT1.is128BitVector()) {
11435+
VT1 = VT1.changeVectorElementType(
11436+
VT1.getVectorElementType().getHalfSizedIntegerVT(Ctx));
11437+
V1 = DAG.getNode(ISD::TRUNCATE, DL, VT1, V1);
11438+
}
11439+
if (VT2.is128BitVector()) {
11440+
VT2 = VT2.changeVectorElementType(
11441+
VT2.getVectorElementType().getHalfSizedIntegerVT(Ctx));
11442+
V2 = DAG.getNode(ISD::TRUNCATE, DL, VT2, V2);
11443+
}
11444+
11445+
assert(VT1 == VT2 && "Mismatched types.");
11446+
Sources[i / 2] =
11447+
DAG.getNode(ISD::CONCAT_VECTORS, DL,
11448+
VT1.getDoubleNumVectorElementsVT(Ctx), V1, V2);
11449+
}
11450+
Sources.resize(Sources.size() / 2);
11451+
}
11452+
11453+
// We might not have the final type in some cases, e.g. <v4i32, v4i32> -> v8i8. Do
11454+
// a final truncating shuffle instead of a concat + trunc.
11455+
if (Sources[0].getValueType() != BVTy) {
11456+
SDValue V1 = Sources[0].getOperand(0);
11457+
SDValue V2 = Sources[0].getOperand(1);
11458+
V1 = DAG.getNode(DAG.getDataLayout().isLittleEndian() ? ISD::BITCAST
11459+
: AArch64ISD::NVCAST,
11460+
DL, BVTy, V1);
11461+
V2 = DAG.getNode(DAG.getDataLayout().isLittleEndian() ? ISD::BITCAST
11462+
: AArch64ISD::NVCAST,
11463+
DL, BVTy, V2);
11464+
11465+
SmallVector<int, 8> MaskVec;
11466+
for (unsigned i = 0; i < BVTy.getVectorNumElements() * 2; i += 2)
11467+
MaskVec.push_back(i);
11468+
return DAG.getVectorShuffle(BVTy, DL, V1, V2, MaskVec);
11469+
}
11470+
return Sources[0];
1142011471
}
1142111472

1142211473
/// Check if a vector shuffle corresponds to a DUP instruction with a larger
@@ -13305,8 +13356,9 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
1330513356
// Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
1330613357
// v4i32s. This is really a truncate, which we can construct out of (legal)
1330713358
// concats and truncate nodes.
13308-
if (SDValue M = ReconstructTruncateFromBuildVector(Op, DAG))
13309-
return M;
13359+
if (AllLanesExtractElt)
13360+
if (SDValue M = ReconstructTruncateFromBuildVector(Op, DAG))
13361+
return M;
1331013362

1331113363
// Empirical tests suggest this is rarely worth it for vectors of length <= 2.
1331213364
if (NumElts >= 4) {
@@ -19096,6 +19148,28 @@ static SDValue performBuildVectorCombine(SDNode *N,
1909619148
SDLoc DL(N);
1909719149
EVT VT = N->getValueType(0);
1909819150

19151+
// BUILD_VECTOR (extract_elt(Assert[S|Z]ext(x)))
19152+
// => BUILD_VECTOR (extract_elt(x))
19153+
SmallVector<SDValue, 8> Ops;
19154+
bool ExtractExtended = false;
19155+
for (SDValue Extr : N->ops()) {
19156+
if (Extr.getOpcode() != ISD::EXTRACT_VECTOR_ELT) {
19157+
ExtractExtended = false;
19158+
break;
19159+
}
19160+
SDValue ExtractBase = Extr.getOperand(0);
19161+
if (ExtractBase.getOpcode() == ISD::AssertSext ||
19162+
ExtractBase.getOpcode() == ISD::AssertZext) {
19163+
ExtractExtended = true;
19164+
Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
19165+
Extr.getValueType(), ExtractBase.getOperand(0),
19166+
Extr.getOperand(1)));
19167+
} else
19168+
Ops.push_back(Extr);
19169+
}
19170+
if (ExtractExtended)
19171+
return DAG.getBuildVector(VT, DL, Ops);
19172+
1909919173
// A build vector of two extracted elements is equivalent to an
1910019174
// extract subvector where the inner vector is any-extended to the
1910119175
// extract_vector_elt VT.

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7595,17 +7595,13 @@ defm USHR : SIMDVectorRShiftBHSD<1, 0b00000, "ushr", AArch64vlshr>;
75957595
defm USRA : SIMDVectorRShiftBHSDTied<1, 0b00010, "usra",
75967596
TriOpFrag<(add_and_or_is_add node:$LHS, (AArch64vlshr node:$MHS, node:$RHS))> >;
75977597

7598-
def VImm0080: PatLeaf<(AArch64movi_shift (i32 128), (i32 0))>;
7599-
def VImm00008000: PatLeaf<(AArch64movi_shift (i32 128), (i32 8))>;
7600-
def VImm0000000080000000: PatLeaf<(AArch64NvCast (v2f64 (fneg (AArch64NvCast (v4i32 (AArch64movi_shift (i32 128), (i32 24)))))))>;
7601-
76027598
// RADDHN patterns for when RSHRN shifts by half the size of the vector element
7603-
def : Pat<(v8i8 (trunc (AArch64vlshr (add (v8i16 V128:$Vn), VImm0080), (i32 8)))),
7599+
def : Pat<(v8i8 (trunc (AArch64vlshr (add (v8i16 V128:$Vn), (AArch64movi_shift (i32 128), (i32 0))), (i32 8)))),
76047600
(RADDHNv8i16_v8i8 V128:$Vn, (v8i16 (MOVIv2d_ns (i32 0))))>;
7605-
def : Pat<(v4i16 (trunc (AArch64vlshr (add (v4i32 V128:$Vn), VImm00008000), (i32 16)))),
7601+
def : Pat<(v4i16 (trunc (AArch64vlshr (add (v4i32 V128:$Vn), (AArch64movi_shift (i32 128), (i32 8))), (i32 16)))),
76067602
(RADDHNv4i32_v4i16 V128:$Vn, (v4i32 (MOVIv2d_ns (i32 0))))>;
76077603
let AddedComplexity = 5 in
7608-
def : Pat<(v2i32 (trunc (AArch64vlshr (add (v2i64 V128:$Vn), VImm0000000080000000), (i32 32)))),
7604+
def : Pat<(v2i32 (trunc (AArch64vlshr (add (v2i64 V128:$Vn), (AArch64dup (i64 2147483648))), (i32 32)))),
76097605
(RADDHNv2i64_v2i32 V128:$Vn, (v2i64 (MOVIv2d_ns (i32 0))))>;
76107606
def : Pat<(v8i8 (int_aarch64_neon_rshrn (v8i16 V128:$Vn), (i32 8))),
76117607
(RADDHNv8i16_v8i8 V128:$Vn, (v8i16 (MOVIv2d_ns (i32 0))))>;
@@ -7617,20 +7613,20 @@ def : Pat<(v2i32 (int_aarch64_neon_rshrn (v2i64 V128:$Vn), (i32 32))),
76177613
// RADDHN2 patterns for when RSHRN shifts by half the size of the vector element
76187614
def : Pat<(v16i8 (concat_vectors
76197615
(v8i8 V64:$Vd),
7620-
(v8i8 (trunc (AArch64vlshr (add (v8i16 V128:$Vn), VImm0080), (i32 8)))))),
7616+
(v8i8 (trunc (AArch64vlshr (add (v8i16 V128:$Vn), (AArch64movi_shift (i32 128), (i32 0))), (i32 8)))))),
76217617
(RADDHNv8i16_v16i8
76227618
(INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
76237619
(v8i16 (MOVIv2d_ns (i32 0))))>;
76247620
def : Pat<(v8i16 (concat_vectors
76257621
(v4i16 V64:$Vd),
7626-
(v4i16 (trunc (AArch64vlshr (add (v4i32 V128:$Vn), VImm00008000), (i32 16)))))),
7622+
(v4i16 (trunc (AArch64vlshr (add (v4i32 V128:$Vn), (AArch64movi_shift (i32 128), (i32 8))), (i32 16)))))),
76277623
(RADDHNv4i32_v8i16
76287624
(INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
76297625
(v4i32 (MOVIv2d_ns (i32 0))))>;
76307626
let AddedComplexity = 5 in
76317627
def : Pat<(v4i32 (concat_vectors
76327628
(v2i32 V64:$Vd),
7633-
(v2i32 (trunc (AArch64vlshr (add (v2i64 V128:$Vn), VImm0000000080000000), (i32 32)))))),
7629+
(v2i32 (trunc (AArch64vlshr (add (v2i64 V128:$Vn), (AArch64dup (i64 2147483648))), (i32 32)))))),
76347630
(RADDHNv2i64_v4i32
76357631
(INSERT_SUBREG (IMPLICIT_DEF), V64:$Vd, dsub), V128:$Vn,
76367632
(v2i64 (MOVIv2d_ns (i32 0))))>;

llvm/test/CodeGen/AArch64/arm64-convert-v4f64.ll

Lines changed: 8 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -8,9 +8,8 @@ define <4 x i16> @fptosi_v4f64_to_v4i16(ptr %ptr) {
88
; CHECK-NEXT: ldp q0, q1, [x0]
99
; CHECK-NEXT: fcvtzs v1.2d, v1.2d
1010
; CHECK-NEXT: fcvtzs v0.2d, v0.2d
11-
; CHECK-NEXT: xtn v1.2s, v1.2d
12-
; CHECK-NEXT: xtn v0.2s, v0.2d
13-
; CHECK-NEXT: uzp1 v0.4h, v0.4h, v1.4h
11+
; CHECK-NEXT: uzp1 v0.4s, v0.4s, v1.4s
12+
; CHECK-NEXT: xtn v0.4h, v0.4s
1413
; CHECK-NEXT: ret
1514
%tmp1 = load <4 x double>, ptr %ptr
1615
%tmp2 = fptosi <4 x double> %tmp1 to <4 x i16>
@@ -26,13 +25,10 @@ define <8 x i8> @fptosi_v4f64_to_v4i8(ptr %ptr) {
2625
; CHECK-NEXT: fcvtzs v1.2d, v1.2d
2726
; CHECK-NEXT: fcvtzs v3.2d, v3.2d
2827
; CHECK-NEXT: fcvtzs v2.2d, v2.2d
29-
; CHECK-NEXT: xtn v0.2s, v0.2d
30-
; CHECK-NEXT: xtn v1.2s, v1.2d
31-
; CHECK-NEXT: xtn v3.2s, v3.2d
32-
; CHECK-NEXT: xtn v2.2s, v2.2d
33-
; CHECK-NEXT: uzp1 v0.4h, v1.4h, v0.4h
34-
; CHECK-NEXT: uzp1 v1.4h, v2.4h, v3.4h
35-
; CHECK-NEXT: uzp1 v0.8b, v1.8b, v0.8b
28+
; CHECK-NEXT: uzp1 v0.4s, v1.4s, v0.4s
29+
; CHECK-NEXT: uzp1 v1.4s, v2.4s, v3.4s
30+
; CHECK-NEXT: uzp1 v0.8h, v1.8h, v0.8h
31+
; CHECK-NEXT: xtn v0.8b, v0.8h
3632
; CHECK-NEXT: ret
3733
%tmp1 = load <8 x double>, ptr %ptr
3834
%tmp2 = fptosi <8 x double> %tmp1 to <8 x i8>
@@ -72,9 +68,8 @@ define <4 x i16> @fptoui_v4f64_to_v4i16(ptr %ptr) {
7268
; CHECK-NEXT: ldp q0, q1, [x0]
7369
; CHECK-NEXT: fcvtzs v1.2d, v1.2d
7470
; CHECK-NEXT: fcvtzs v0.2d, v0.2d
75-
; CHECK-NEXT: xtn v1.2s, v1.2d
76-
; CHECK-NEXT: xtn v0.2s, v0.2d
77-
; CHECK-NEXT: uzp1 v0.4h, v0.4h, v1.4h
71+
; CHECK-NEXT: uzp1 v0.4s, v0.4s, v1.4s
72+
; CHECK-NEXT: xtn v0.4h, v0.4s
7873
; CHECK-NEXT: ret
7974
%tmp1 = load <4 x double>, ptr %ptr
8075
%tmp2 = fptoui <4 x double> %tmp1 to <4 x i16>

llvm/test/CodeGen/AArch64/fp-conversion-to-tbl.ll

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -73,9 +73,8 @@ define void @fptoui_v8f32_to_v8i8_no_loop(ptr %A, ptr %dst) {
7373
; CHECK-NEXT: ldp q0, q1, [x0]
7474
; CHECK-NEXT: fcvtzs.4s v1, v1
7575
; CHECK-NEXT: fcvtzs.4s v0, v0
76-
; CHECK-NEXT: xtn.4h v1, v1
77-
; CHECK-NEXT: xtn.4h v0, v0
78-
; CHECK-NEXT: uzp1.8b v0, v0, v1
76+
; CHECK-NEXT: uzp1.8h v0, v0, v1
77+
; CHECK-NEXT: xtn.8b v0, v0
7978
; CHECK-NEXT: str d0, [x1]
8079
; CHECK-NEXT: ret
8180
entry:

0 commit comments

Comments
 (0)