@@ -11369,54 +11369,105 @@ static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
11369
11369
return true;
11370
11370
}
11371
11371
11372
- // Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
11373
- // v4i32s. This is really a truncate, which we can construct out of (legal)
11374
- // concats and truncate nodes.
11375
- static SDValue ReconstructTruncateFromBuildVector(SDValue V, SelectionDAG &DAG) {
11376
- if (V.getValueType() != MVT::v16i8)
11377
- return SDValue();
11378
- assert(V.getNumOperands() == 16 && "Expected 16 operands on the BUILDVECTOR");
11379
-
11380
- for (unsigned X = 0; X < 4; X++) {
11381
- // Check the first item in each group is an extract from lane 0 of a v4i32
11382
- // or v4i16.
11383
- SDValue BaseExt = V.getOperand(X * 4);
11384
- if (BaseExt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
11385
- (BaseExt.getOperand(0).getValueType() != MVT::v4i16 &&
11386
- BaseExt.getOperand(0).getValueType() != MVT::v4i32) ||
11387
- !isa<ConstantSDNode>(BaseExt.getOperand(1)) ||
11388
- BaseExt.getConstantOperandVal(1) != 0)
11372
+ // Detect patterns like a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3, that
11373
+ // are truncates, which we can construct out of (legal) concats and truncate
11374
+ // nodes.
11375
+ static SDValue ReconstructTruncateFromBuildVector(SDValue V,
11376
+ SelectionDAG &DAG) {
11377
+ EVT BVTy = V.getValueType();
11378
+ if (BVTy != MVT::v16i8 && BVTy != MVT::v8i16 && BVTy != MVT::v8i8 &&
11379
+ BVTy != MVT::v4i16)
11380
+ return SDValue();
11381
+
11382
+ // Only handle truncating BVs.
11383
+ if (V.getOperand(0).getValueType().getSizeInBits() ==
11384
+ BVTy.getScalarSizeInBits())
11385
+ return SDValue();
11386
+
11387
+ SmallVector<SDValue, 4> Sources;
11388
+ uint64_t LastIdx = 0;
11389
+ uint64_t MaxIdx = 0;
11390
+ // Check for sequential indices e.g. i=0, i+1, ..., i=0, i+1, ...
11391
+ for (SDValue Extr : V->ops()) {
11392
+ SDValue SourceVec = Extr.getOperand(0);
11393
+ EVT SourceVecTy = SourceVec.getValueType();
11394
+
11395
+ if (!DAG.getTargetLoweringInfo().isTypeLegal(SourceVecTy))
11389
11396
return SDValue();
11390
- SDValue Base = BaseExt.getOperand(0);
11391
- // And check the other items are extracts from the same vector.
11392
- for (unsigned Y = 1; Y < 4; Y++) {
11393
- SDValue Ext = V.getOperand(X * 4 + Y);
11394
- if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
11395
- Ext.getOperand(0) != Base ||
11396
- !isa<ConstantSDNode>(Ext.getOperand(1)) ||
11397
- Ext.getConstantOperandVal(1) != Y)
11397
+ if (!isa<ConstantSDNode>(Extr.getOperand(1)))
11398
+ return SDValue();
11399
+
11400
+ uint64_t CurIdx = Extr.getConstantOperandVal(1);
11401
+ // Allow repeat of sources.
11402
+ if (CurIdx == 0) {
11403
+ // Check if all lanes are used by the BV.
11404
+ if (Sources.size() && Sources[Sources.size() - 1]
11405
+ .getValueType()
11406
+ .getVectorMinNumElements() != LastIdx + 1)
11398
11407
return SDValue();
11399
- }
11408
+ Sources.push_back(SourceVec);
11409
+ } else if (CurIdx != LastIdx + 1)
11410
+ return SDValue();
11411
+
11412
+ LastIdx = CurIdx;
11413
+ MaxIdx = std::max(MaxIdx, CurIdx);
11400
11414
}
11401
11415
11402
- // Turn the buildvector into a series of truncates and concates, which will
11403
- // become uzip1's. Any v4i32s we found get truncated to v4i16, which are
11404
- // concat together to produce 2 v8i16. These are both truncated and concat
11405
- // together.
11416
+ // Check if all lanes are used by the BV.
11417
+ if (Sources[Sources.size() - 1].getValueType().getVectorMinNumElements() !=
11418
+ LastIdx + 1)
11419
+ return SDValue();
11420
+ if (Sources.size() % 2 != 0)
11421
+ return SDValue();
11422
+
11423
+ // At this point we know that we have a truncating BV of extract_vector_elt.
11424
+ // We can just truncate and concat them.
11406
11425
SDLoc DL(V);
11407
- SDValue Trunc[4] = {
11408
- V.getOperand(0).getOperand(0), V.getOperand(4).getOperand(0),
11409
- V.getOperand(8).getOperand(0), V.getOperand(12).getOperand(0)};
11410
- for (SDValue &V : Trunc)
11411
- if (V.getValueType() == MVT::v4i32)
11412
- V = DAG.getNode(ISD::TRUNCATE, DL, MVT::v4i16, V);
11413
- SDValue Concat0 =
11414
- DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[0], Trunc[1]);
11415
- SDValue Concat1 =
11416
- DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[2], Trunc[3]);
11417
- SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat0);
11418
- SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat1);
11419
- return DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, Trunc0, Trunc1);
11426
+ LLVMContext &Ctx = *DAG.getContext();
11427
+ while (Sources.size() > 1) {
11428
+ for (unsigned i = 0; i < Sources.size(); i += 2) {
11429
+ SDValue V1 = Sources[i];
11430
+ SDValue V2 = Sources[i + 1];
11431
+ EVT VT1 = V1.getValueType();
11432
+ EVT VT2 = V2.getValueType();
11433
+
11434
+ if (VT1.is128BitVector()) {
11435
+ VT1 = VT1.changeVectorElementType(
11436
+ VT1.getVectorElementType().getHalfSizedIntegerVT(Ctx));
11437
+ V1 = DAG.getNode(ISD::TRUNCATE, DL, VT1, V1);
11438
+ }
11439
+ if (VT2.is128BitVector()) {
11440
+ VT2 = VT2.changeVectorElementType(
11441
+ VT2.getVectorElementType().getHalfSizedIntegerVT(Ctx));
11442
+ V2 = DAG.getNode(ISD::TRUNCATE, DL, VT2, V2);
11443
+ }
11444
+
11445
+ assert(VT1 == VT2 && "Mismatched types.");
11446
+ Sources[i / 2] =
11447
+ DAG.getNode(ISD::CONCAT_VECTORS, DL,
11448
+ VT1.getDoubleNumVectorElementsVT(Ctx), V1, V2);
11449
+ }
11450
+ Sources.resize(Sources.size() / 2);
11451
+ }
11452
+
11453
+ // We might not have the final type in some cases e.g. <4i32, 4i32> -> 8i8. Do
11454
+ // a final truncating shuffle instead of a concat + trunc.
11455
+ if (Sources[0].getValueType() != BVTy) {
11456
+ SDValue V1 = Sources[0].getOperand(0);
11457
+ SDValue V2 = Sources[0].getOperand(1);
11458
+ V1 = DAG.getNode(DAG.getDataLayout().isLittleEndian() ? ISD::BITCAST
11459
+ : AArch64ISD::NVCAST,
11460
+ DL, BVTy, V1);
11461
+ V2 = DAG.getNode(DAG.getDataLayout().isLittleEndian() ? ISD::BITCAST
11462
+ : AArch64ISD::NVCAST,
11463
+ DL, BVTy, V2);
11464
+
11465
+ SmallVector<int, 8> MaskVec;
11466
+ for (unsigned i = 0; i < BVTy.getVectorNumElements() * 2; i += 2)
11467
+ MaskVec.push_back(i);
11468
+ return DAG.getVectorShuffle(BVTy, DL, V1, V2, MaskVec);
11469
+ }
11470
+ return Sources[0];
11420
11471
}
11421
11472
11422
11473
/// Check if a vector shuffle corresponds to a DUP instructions with a larger
@@ -13305,8 +13356,9 @@ SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
13305
13356
// Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
13306
13357
// v4i32s. This is really a truncate, which we can construct out of (legal)
13307
13358
// concats and truncate nodes.
13308
- if (SDValue M = ReconstructTruncateFromBuildVector(Op, DAG))
13309
- return M;
13359
+ if (AllLanesExtractElt)
13360
+ if (SDValue M = ReconstructTruncateFromBuildVector(Op, DAG))
13361
+ return M;
13310
13362
13311
13363
// Empirical tests suggest this is rarely worth it for vectors of length <= 2.
13312
13364
if (NumElts >= 4) {
@@ -19096,6 +19148,28 @@ static SDValue performBuildVectorCombine(SDNode *N,
19096
19148
SDLoc DL(N);
19097
19149
EVT VT = N->getValueType(0);
19098
19150
19151
+ // BUILD_VECTOR (extract_elt(Assert[S|Z]ext(x)))
19152
+ // => BUILD_VECTOR (extract_elt(x))
19153
+ SmallVector<SDValue, 8> Ops;
19154
+ bool ExtractExtended = false;
19155
+ for (SDValue Extr : N->ops()) {
19156
+ if (Extr.getOpcode() != ISD::EXTRACT_VECTOR_ELT) {
19157
+ ExtractExtended = false;
19158
+ break;
19159
+ }
19160
+ SDValue ExtractBase = Extr.getOperand(0);
19161
+ if (ExtractBase.getOpcode() == ISD::AssertSext ||
19162
+ ExtractBase.getOpcode() == ISD::AssertZext) {
19163
+ ExtractExtended = true;
19164
+ Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
19165
+ Extr.getValueType(), ExtractBase.getOperand(0),
19166
+ Extr.getOperand(1)));
19167
+ } else
19168
+ Ops.push_back(Extr);
19169
+ }
19170
+ if (ExtractExtended)
19171
+ return DAG.getBuildVector(VT, DL, Ops);
19172
+
19099
19173
// A build vector of two extracted elements is equivalent to an
19100
19174
// extract subvector where the inner vector is any-extended to the
19101
19175
// extract_vector_elt VT.
0 commit comments