@@ -9800,6 +9800,26 @@ SDValue AArch64TargetLowering::LowerCTPOP_PARITY(SDValue Op,
9800
9800
Val = DAG.getBitcast(VT8Bit, Val);
9801
9801
Val = DAG.getNode(ISD::CTPOP, DL, VT8Bit, Val);
9802
9802
9803
+ if (Subtarget->hasDotProd() && VT.getScalarSizeInBits() != 16) {
9804
+ EVT DT = VT == MVT::v2i64 ? MVT::v4i32 : VT;
9805
+ SDValue Zeros = DAG.getSplatBuildVector(
9806
+ DT, DL, DAG.getConstant(0, DL, DT.getScalarType()));
9807
+ SDValue Ones =
9808
+ DAG.getSplatBuildVector(VT8Bit, DL, DAG.getConstant(1, DL, MVT::i8));
9809
+
9810
+ if (VT == MVT::v2i64) {
9811
+ Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
9812
+ Val = DAG.getNode(AArch64ISD::UADDLP, DL, VT, Val);
9813
+ } else if (VT == MVT::v2i32) {
9814
+ Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
9815
+ } else if (VT == MVT::v4i32) {
9816
+ Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
9817
+ } else {
9818
+ llvm_unreachable("Unexpected type for custom ctpop lowering");
9819
+ }
9820
+
9821
+ return Val;
9822
+ }
9803
9823
// Widen v8i8/v16i8 CTPOP result to VT by repeatedly widening pairwise adds.
9804
9824
unsigned EltSize = 8;
9805
9825
unsigned NumElts = VT.is64BitVector() ? 8 : 16;
0 commit comments