@@ -6379,42 +6379,86 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
6379
6379
assert((Op1VT.getVectorElementType() == MVT::i8 ||
6380
6380
Op1VT.getVectorElementType() == MVT::i16) &&
6381
6381
"Expected 8-bit or 16-bit characters.");
6382
- assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
6383
6382
assert(Op1VT.getVectorElementType() == Op2VT.getVectorElementType() &&
6384
6383
"Operand type mismatch.");
6385
- assert(Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements() &&
6386
- "Invalid operands.");
6387
-
6388
- // Wrap the search vector in a scalable vector.
6389
- EVT OpContainerVT = getContainerForFixedLengthVector(DAG, Op2VT);
6390
- Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
6391
-
6392
- // If the result is scalable, we need to broadbast the search vector across
6393
- // the SVE register and then carry out the MATCH.
6394
- if (ResVT.isScalableVector()) {
6395
- Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
6396
- DAG.getTargetConstant(0, dl, MVT::i64));
6384
+ assert(!Op2VT.isScalableVector() && "Search vector cannot be scalable.");
6385
+
6386
+ // Note: Currently Op1 needs to be v16i8, v8i16, or the scalable versions.
6387
+ // In the future we could support other types (e.g. v8i8).
6388
+ assert(Op1VT.getSizeInBits().getKnownMinValue() == 128 &&
6389
+ "Unsupported first operand type.");
6390
+
6391
+ // Scalable vector type used to wrap operands.
6392
+ // A single container is enough for both operands because ultimately the
6393
+ // operands will have to be wrapped to the same type (nxv16i8 or nxv8i16).
6394
+ EVT OpContainerVT = Op1VT.isScalableVector()
6395
+ ? Op1VT
6396
+ : getContainerForFixedLengthVector(DAG, Op1VT);
6397
+
6398
+ // Wrap Op2 in a scalable register, and splat it if necessary.
6399
+ if (Op1VT.getVectorMinNumElements() == Op2VT.getVectorNumElements()) {
6400
+ // If Op1 and Op2 have the same number of elements we can trivially
6401
+ // wrap Op2 in an SVE register.
6402
+ Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
6403
+ // If the result is scalable, we need to broadcast Op2 to a full SVE
6404
+ // register.
6405
+ if (ResVT.isScalableVector())
6406
+ Op2 = DAG.getNode(AArch64ISD::DUPLANE128, dl, OpContainerVT, Op2,
6407
+ DAG.getTargetConstant(0, dl, MVT::i64));
6408
+ } else {
6409
+ // If Op1 and Op2 have different number of elements, we need to broadcast
6410
+ // Op2. Ideally we would use a AArch64ISD::DUPLANE* node for this
6411
+ // similarly to the above, but unfortunately it seems we are missing some
6412
+ // patterns for this. So, as an alternative, we splat Op2 through a splat of
6413
+ // a scalable vector extract. This idiom, though a bit more verbose, is
6414
+ // supported and gets us the MOV instruction we want.
6415
+
6416
+ // Some types we need. We'll use an integer type with `Op2BitWidth' bits
6417
+ // to wrap Op2 and simulate the DUPLANE.
6418
+ unsigned Op2BitWidth = Op2VT.getFixedSizeInBits();
6419
+ MVT Op2IntVT = MVT::getIntegerVT(Op2BitWidth);
6420
+ MVT Op2FixedVT = MVT::getVectorVT(Op2IntVT, 128 / Op2BitWidth);
6421
+ EVT Op2ScalableVT = getContainerForFixedLengthVector(DAG, Op2FixedVT);
6422
+ // Widen Op2 to a full 128-bit register. We need this to wrap Op2 in an
6423
+ // SVE register before doing the extract and splat.
6424
+ // It is unlikely we'll be widening from types other than v8i8 or v4i16,
6425
+ // so in practice this loop will run for a single iteration.
6426
+ while (Op2VT.getFixedSizeInBits() != 128) {
6427
+ Op2VT = Op2VT.getDoubleNumVectorElementsVT(*DAG.getContext());
6428
+ Op2 = DAG.getNode(ISD::CONCAT_VECTORS, dl, Op2VT, Op2,
6429
+ DAG.getUNDEF(Op2.getValueType()));
6430
+ }
6431
+ // Wrap Op2 in a scalable vector and do the splat of its 0-index lane.
6432
+ Op2 = convertToScalableVector(DAG, OpContainerVT, Op2);
6433
+ Op2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op2IntVT,
6434
+ DAG.getBitcast(Op2ScalableVT, Op2),
6435
+ DAG.getConstant(0, dl, MVT::i64));
6436
+ Op2 = DAG.getSplatVector(Op2ScalableVT, dl, Op2);
6437
+ Op2 = DAG.getBitcast(OpContainerVT, Op2);
6438
+ }
6439
+
6440
+ // If the result is scalable, we just need to carry out the MATCH.
6441
+ if (ResVT.isScalableVector())
6397
6442
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, ResVT, ID, Mask, Op1,
6398
6443
Op2);
6399
- }
6400
6444
6401
6445
// If the result is fixed, we can still use MATCH but we need to wrap the
6402
6446
// first operand and the mask in scalable vectors before doing so.
6403
- EVT MatchVT = OpContainerVT.changeElementType(MVT::i1);
6404
6447
6405
6448
// Wrap the operands.
6406
6449
Op1 = convertToScalableVector(DAG, OpContainerVT, Op1);
6407
6450
Mask = DAG.getNode(ISD::ANY_EXTEND, dl, Op1VT, Mask);
6408
6451
Mask = convertFixedMaskToScalableVector(Mask, DAG);
6409
6452
6410
- // Carry out the match.
6411
- SDValue Match =
6412
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MatchVT, ID, Mask, Op1, Op2);
6453
+ // Carry out the match and extract it.
6454
+ SDValue Match = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
6455
+ Mask.getValueType(), ID, Mask, Op1, Op2);
6456
+ Match = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
6457
+ DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
6458
+ DAG.getVectorIdxConstant(0, dl));
6413
6459
6414
- // Extract and return the result.
6415
- return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, Op1VT,
6416
- DAG.getNode(ISD::SIGN_EXTEND, dl, OpContainerVT, Match),
6417
- DAG.getVectorIdxConstant(0, dl));
6460
+ // Truncate and return the result.
6461
+ return DAG.getNode(ISD::TRUNCATE, dl, ResVT, Match);
6418
6462
}
6419
6463
}
6420
6464
}
0 commit comments