[WIP][RFC] Implementation for SVE2 long operations #89310

Open · wants to merge 1 commit into base: main

Conversation

UsmanNadeem (Contributor) commented Apr 18, 2024

I have written this patch to show the kind of optimized codegen we should expect and to get feedback on the codegen approach.

Also note that the loop vectorizer currently does not generate wide scalable-vector IR (probably because of the cost model), so the attached test case was manually converted from fixed-vector IR to scalable-vector form.

There are a few issues that make the SVE2 widening-op implementation not so straightforward:

  • SVE2 widening operations differ from NEON in that they operate on the even/odd halves of the vector rather than on the hi/lo halves. This makes things tricky: to get the "natural" form of the result, i.e. hi/lo halves, we need to interleave the results of the bottom/top instructions. So there is no 1:1 mapping, and because of the extra shuffles we have to bundle the bot/top pair together with the interleave.
  • It is difficult to write patterns on legal types because legalization produces unpacks that are in hi/lo form. The even/odd form of the operations needs to reason about lanes on both sides of the middle boundary, so we need access to the wide, illegally typed operation in order to lower it to SVE2 long operations.
  • Since the order of lanes in the input vectors does not matter for basic arithmetic operations (as long as the LHS and RHS lanes correspond to each other), some of the intermediate shuffles can be optimized away when one widening instruction feeds another. Right now I have added that logic to the same function and handle it during lowering, but I suspect it could instead be a combine on the vector_interleave node (I am not fully sure; I need to check the DAGs for various types).

Example:

// Legalized operations are difficult to pattern match for SVE2.
8xi16 V = 7 6 5 4 3 2 1 0
v_unpklo = 3 2 1 0
v_unpkhi = 7 6 5 4
add(v_unpklo, v_unpklo) = 6 4 2 0
add(v_unpkhi, v_unpkhi) = 14 12 10 8

// We cannot look at either of the two legal adds above in isolation.

// SVE2 addlb(V, V) works on lanes 6 4 2 0 and produces 12 8 4 0.
// SVE2 addlt(V, V) works on lanes 7 5 3 1 and produces 14 10 6 2.
// We need shuffles to get [14 12 10 8] and [6 4 2 0].
// There is no root node through which we can access both adds, so
// we need to generate both bot/top + shuffles at the same time.
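
In IR terms, the basic pattern this patch matches is an add fed by two zero-extends, as in the attached test. The sketch below mirrors the nxv16i8 -> nxv16i16 case from the new test file; the function name and the register allocation in the expected SVE2 output are illustrative only:

define <vscale x 16 x i16> @uaddl_sketch(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
  %za = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
  %zb = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
  %add = add nuw nsw <vscale x 16 x i16> %za, %zb
  ret <vscale x 16 x i16> %add
}

; Expected SVE2 selection: the bot/top pair plus the zips that restore lane order.
;   uaddlt z2.h, z0.b, z1.b
;   uaddlb z3.h, z0.b, z1.b
;   zip1   z0.h, z3.h, z2.h
;   zip2   z1.h, z3.h, z2.h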

Additionally, if the results are stored, we could use structured stores to get rid of the zips; however, those stores are currently generated before the SelectionDAG by the InterleavedAccess pass, so we might need something for the DAG phase as well.
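
For illustration only, a minimal sketch of what that could look like if the zips were folded into an interleaving st2 store. The intrinsic spellings, the predicate and the destination operands are assumptions for discussion, not something this patch generates:

; %a, %b, %pg and %dst are assumed inputs.
; Instead of zip1/zip2 of the bot/top results followed by two contiguous
; stores, store the bot/top results directly with st2, which interleaves
; them in memory and makes the zips unnecessary.
%bot = call <vscale x 8 x i16> @llvm.aarch64.sve.uaddlb.nxv8i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
%top = call <vscale x 8 x i16> @llvm.aarch64.sve.uaddlt.nxv8i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b)
call void @llvm.aarch64.sve.st2.nxv8i16(<vscale x 8 x i16> %bot, <vscale x 8 x i16> %top, <vscale x 8 x i1> %pg, ptr %dst)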

Change-Id: I6e8f70342ff25f6ab21cd5666c9085be0fa2e206
llvmbot (Member) commented Apr 18, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Usman Nadeem (UsmanNadeem)

Changes

Patch is 45.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89310.diff

5 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+292)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td (+18-2)
  • (modified) llvm/test/CodeGen/AArch64/sve-doublereduct.ll (+7-9)
  • (added) llvm/test/CodeGen/AArch64/sve2-uaddl.ll (+636)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7947d73f9a4dd0..fa8ec9b7a55f21 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -49,6 +49,7 @@
 #include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RuntimeLibcalls.h"
+#include "llvm/CodeGen/SDPatternMatch.h"
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "llvm/CodeGen/SelectionDAGNodes.h"
 #include "llvm/CodeGen/TargetCallingConv.h"
@@ -104,6 +105,7 @@
 
 using namespace llvm;
 using namespace llvm::PatternMatch;
+namespace sd = llvm::SDPatternMatch;
 
 #define DEBUG_TYPE "aarch64-lower"
 
@@ -1416,6 +1418,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::OR, VT, Custom);
     }
 
+    // Illegal wide integer scalable vector types.
+    if (Subtarget->hasSVE2orSME()) {
+      for (auto VT : {MVT::nxv16i16, MVT::nxv16i32, MVT::nxv16i64})
+        setOperationAction(ISD::ADD, VT, Custom);
+      for (auto VT : {MVT::nxv8i32, MVT::nxv8i64})
+        setOperationAction(ISD::ADD, VT, Custom);
+      setOperationAction(ISD::ADD, MVT::nxv4i64, Custom);
+    }
+
     // Illegal unpacked integer vector types.
     for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
@@ -2725,6 +2736,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::CTTZ_ELTS)
     MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
     MAKE_CASE(AArch64ISD::URSHR_I_PRED)
+    MAKE_CASE(AArch64ISD::UADDLB)
+    MAKE_CASE(AArch64ISD::UADDLT)
   }
 #undef MAKE_CASE
   return nullptr;
@@ -25081,6 +25094,282 @@ void AArch64TargetLowering::ReplaceBITCASTResults(
   Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Op));
 }
 
+static bool matchUADDLOps(SDNode *N, SelectionDAG &DAG, SDValue &A, SDValue &B,
+                          unsigned &BotOpc, unsigned &TopOpc) {
+  BotOpc = AArch64ISD::UADDLB;
+  TopOpc = AArch64ISD::UADDLT;
+  if (sd_match(N, sd::m_Add(sd::m_OneUse(sd::m_ZExt(sd::m_Value(A))),
+                            sd::m_OneUse(sd::m_ZExt(sd::m_Value(B))))))
+
+    return true;
+
+#if 0
+  // Extended loads.
+  if (sd_match(N, sd::m_Add(sd::m_OneUse(sd::m_ZExt(sd::m_Value(A))),
+                            sd::m_OneUse(sd::m_Value(B))))) {
+    auto *LDB = dyn_cast<LoadSDNode>(B);
+    if (LDB && LDB->getExtensionType() == ISD::ZEXTLOAD) {
+      B = DAG.getLoad(LDB->getMemoryVT(), SDLoc(LDB), LDB->getChain(),
+                      LDB->getBasePtr(), LDB->getMemOperand());
+      return true;
+    }
+  } else if (sd_match(N, sd::m_Add(sd::m_OneUse(sd::m_Value(A)),
+                                   sd::m_OneUse(sd::m_Value(B)))) &&
+             isa<LoadSDNode>(A) && isa<LoadSDNode>(B)) {
+    auto *LDA = cast<LoadSDNode>(A);
+    auto *LDB = cast<LoadSDNode>(B);
+    if (LDA->getExtensionType() == ISD::ZEXTLOAD &&
+        LDB->getExtensionType() == ISD::ZEXTLOAD) {
+      A = DAG.getLoad(LDA->getMemoryVT(), SDLoc(LDA), LDA->getChain(),
+                      LDA->getBasePtr(), LDA->getMemOperand());
+      B = DAG.getLoad(LDB->getMemoryVT(), SDLoc(LDB), LDB->getChain(),
+                      LDB->getBasePtr(), LDB->getMemOperand());
+      return true;
+    }
+  }
+#endif
+  return false;
+}
+static bool replaceIntOpWithSVE2LongOp(SDNode *N,
+                                       SmallVectorImpl<SDValue> &Results,
+                                       SelectionDAG &DAG,
+                                       const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasSVE2orSME())
+    return false;
+
+  EVT VT = N->getValueType(0);
+  LLVMContext &Ctx = *DAG.getContext();
+  SDLoc DL(N);
+  SDValue LHS, RHS;
+  unsigned BotOpc, TopOpc;
+
+  auto CreateLongOpPair = [&](SDValue LHS,
+                              SDValue RHS) -> std::pair<SDValue, SDValue> {
+    EVT WideResVT = LHS.getValueType()
+                        .widenIntegerVectorElementType(Ctx)
+                        .getHalfNumVectorElementsVT(Ctx);
+    SDValue Even = DAG.getNode(BotOpc, DL, WideResVT, LHS, RHS);
+    SDValue Odd = DAG.getNode(TopOpc, DL, WideResVT, LHS, RHS);
+    return std::make_pair(Even, Odd);
+  };
+
+  bool MatchedLongOp = matchUADDLOps(N, DAG, LHS, RHS, BotOpc, TopOpc);
+  // Should also work for similar long instructions.
+  // if (!MatchedLongOp) MatchedLongOp = match<OtherLongInstr>Ops(...);
+  if (!MatchedLongOp || LHS.getValueType() != RHS.getValueType())
+    return false;
+  EVT UnExtVT = LHS.getValueType();
+
+  // 128-bit unextended operands.
+  if (UnExtVT == MVT::nxv16i8 || UnExtVT == MVT::nxv8i16 ||
+      UnExtVT == MVT::nxv4i32) {
+    auto [Even, Odd] = CreateLongOpPair(LHS, RHS);
+    EVT WideResVT = Even.getValueType();
+    // Widening operations deinterleaves the results. Shuffle them to get
+    // their natural order.
+    SDValue Interleave =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), Even, Odd);
+    SDValue Concat = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        Interleave.getValue(0), Interleave.getValue(1));
+    Results.push_back(DAG.getZExtOrTrunc(Concat, DL, VT));
+    return true;
+  }
+
+  // 256-bit/512-bit unextended operands. Try to optimize by reducing the number
+  // of shuffles in cases where the operands are interleaved from existing
+  // even/odd pairs.
+  if (UnExtVT == MVT::nxv16i16 || UnExtVT == MVT::nxv8i32) {
+    // For the pattern:
+    //   (LHSBot, LHSTop) = vector_interleave(LHSEven, LHSOdd)
+    //   (RHSBot, RHSTop) = vector_interleave(RHSEven, RHSOdd)
+    //   LHS = concat(LHSBot, LHSTop)
+    //   RHS = concat(RHSBot, RHSTop)
+    //   op(zext(LHS), zext(RHS))
+    // We can use the pre-interleaved operands to create the longOp(b|t) and
+    // push the shuffles across the operation.
+    SDValue LHSBot, LHSTop, RHSBot, RHSTop;
+    SDValue LHSEven, LHSOdd, RHSEven, RHSOdd;
+
+    if (!sd_match(LHS, sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(LHSBot),
+                                  sd::m_Value(LHSTop))))
+      return false;
+    if (LHSTop.getNode() != LHSBot.getNode() || LHSTop == LHSBot ||
+        !sd_match(LHSBot.getNode(),
+                  sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(LHSEven),
+                             sd::m_Value(LHSOdd))))
+      return false;
+
+    if (!sd_match(RHS, sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(RHSBot),
+                                  sd::m_Value(RHSTop))))
+      return false;
+    if (RHSTop.getNode() != RHSBot.getNode() || RHSTop == RHSBot ||
+        !sd_match(RHSBot.getNode(),
+                  sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(RHSEven),
+                             sd::m_Value(RHSOdd))))
+      return false;
+
+    // Do the following:
+    //   v0 = longOpb(LHSEven, RHSEven)
+    //   v1 = longOpt(LHSEven, RHSEven)
+    //   v2 = longOpb(LHSOdd, RHSOdd)
+    //   v3 = longOpt(LHSOdd, RHSOdd)
+    //   InterleaveEven = interleave(v0, v2)
+    //   InterleaveOdd  = interleave(v1, v3)
+    //   concat(InterleaveEven[0], InterleaveOdd[0], InterleaveEven[1],
+    //   InterleaveOdd[1])
+    auto [V0, V1] = CreateLongOpPair(LHSEven, RHSEven);
+    auto [V2, V3] = CreateLongOpPair(LHSOdd, RHSOdd);
+    EVT WideResVT = V0.getValueType();
+
+    SDValue InterleaveEven =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), V0, V2);
+    SDValue InterleaveOdd =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), V1, V3);
+
+    SDValue Concat0 = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        InterleaveEven.getValue(0), InterleaveOdd.getValue(0));
+    SDValue Concat1 = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        InterleaveEven.getValue(1), InterleaveOdd.getValue(1));
+    SDValue Concat =
+        DAG.getNode(ISD::CONCAT_VECTORS, DL,
+                    Concat0.getValueType().getDoubleNumVectorElementsVT(Ctx),
+                    Concat0, Concat1);
+    Results.push_back(DAG.getZExtOrTrunc(Concat, DL, VT));
+    return true;
+  }
+
+  if (UnExtVT == MVT::nxv16i32) {
+    // [LHS0, LHS2] = interleave(...)
+    // [LHS1, LHS3] = interleave(...)
+    // LHS = concat(concat(LHS0, LHS1), concat(LHS2, LHS3))
+    // See comments for 256-bit unextended operands to understand
+    // where this pattern comes from.
+    // Example:
+    //   LHS = 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
+    //   LHS0 = 3, 2, 1, 0
+    //   LHS1 = 7, 6, 5, 4
+    //   LHS2 = 11, 10, 9, 8
+    //   LHS3 = 15, 14, 13, 12
+    // After Deinterleaving/pre-interleaved values:
+    //   LHS0 = 10, 8, 2, 0
+    //   LHS1 = 14, 12, 6, 4
+    //   LHS2 = 11, 9, 3, 1
+    //   LHS3 = 15, 13, 7, 5
+
+    SDValue LHS0, LHS1, LHS2, LHS3;
+    SDValue RHS0, RHS1, RHS2, RHS3;
+    if (!sd_match(LHS,
+                  sd::m_Node(ISD::CONCAT_VECTORS,
+                             sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(LHS0),
+                                        sd::m_Value(LHS1)),
+                             sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(LHS2),
+                                        sd::m_Value(LHS3)))))
+      return false;
+    if (!sd_match(RHS,
+                  sd::m_Node(ISD::CONCAT_VECTORS,
+                             sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(RHS0),
+                                        sd::m_Value(RHS1)),
+                             sd::m_Node(ISD::CONCAT_VECTORS, sd::m_Value(RHS2),
+                                        sd::m_Value(RHS3)))))
+      return false;
+
+    if (LHS0.getNode() != LHS2.getNode() || LHS0 == LHS2 ||
+        !sd_match(LHS0.getNode(),
+                  sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(LHS0),
+                             sd::m_Value(LHS2))))
+      return false;
+    if (LHS1.getNode() != LHS3.getNode() || LHS1 == LHS3 ||
+        !sd_match(LHS1.getNode(),
+                  sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(LHS1),
+                             sd::m_Value(LHS3))))
+      return false;
+
+    if (RHS0.getNode() != RHS2.getNode() || RHS0 == RHS2 ||
+        !sd_match(RHS0.getNode(),
+                  sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(RHS0),
+                             sd::m_Value(RHS2))))
+      return false;
+    if (RHS1.getNode() != RHS3.getNode() || RHS1 == RHS3 ||
+        !sd_match(RHS1.getNode(),
+                  sd::m_Node(ISD::VECTOR_INTERLEAVE, sd::m_Value(RHS1),
+                             sd::m_Value(RHS3))))
+      return false;
+
+    // After long operation:
+    //   v0 = 8, 0
+    //   v1 = 10, 2
+    //
+    //   v2 = 12, 4
+    //   v3 = 14, 6
+    //
+    //   v4 = 9, 1
+    //   v5 = 11, 3
+    //
+    //   v6 = 13, 5
+    //   v7 = 15, 7
+    auto [V0, V1] = CreateLongOpPair(LHS0, RHS0);
+    auto [V2, V3] = CreateLongOpPair(LHS1, RHS1);
+    auto [V4, V5] = CreateLongOpPair(LHS2, RHS2);
+    auto [V6, V7] = CreateLongOpPair(LHS3, RHS3);
+    EVT WideResVT = V0.getValueType();
+
+    // Now we can interleave and concat:
+    //   i0 = interleave(v0, v4) ; i0 = [(1, 0), (12, 8)]
+    //   i1 = interleave(v1, v5) ; i1 = [(3, 2), (11, 10)]
+    //   i2 = interleave(v2, v6) ; i2 = [(5, 4), (13, 12)]
+    //   i3 = interleave(v3, v7) ; i3 = [(7, 6), (15, 14)]
+    //   res = concat(i0[0], i1[0]...i0[1], i1[1]...)
+    SDValue Interleave0 =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), V0, V4);
+    SDValue Interleave1 =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), V1, V5);
+    SDValue Interleave2 =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), V2, V6);
+    SDValue Interleave3 =
+        DAG.getNode(ISD::VECTOR_INTERLEAVE, DL,
+                    DAG.getVTList(WideResVT, WideResVT), V3, V7);
+
+    SDValue Concat0 = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        Interleave0.getValue(0), Interleave1.getValue(0));
+    SDValue Concat1 = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        Interleave2.getValue(0), Interleave3.getValue(0));
+    SDValue Concat2 = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        Interleave0.getValue(1), Interleave1.getValue(1));
+    SDValue Concat3 = DAG.getNode(
+        ISD::CONCAT_VECTORS, DL, WideResVT.getDoubleNumVectorElementsVT(Ctx),
+        Interleave2.getValue(1), Interleave3.getValue(1));
+    Concat0 =
+        DAG.getNode(ISD::CONCAT_VECTORS, DL,
+                    Concat0.getValueType().getDoubleNumVectorElementsVT(Ctx),
+                    Concat0, Concat1);
+    Concat2 =
+        DAG.getNode(ISD::CONCAT_VECTORS, DL,
+                    Concat2.getValueType().getDoubleNumVectorElementsVT(Ctx),
+                    Concat2, Concat3);
+    Concat0 =
+        DAG.getNode(ISD::CONCAT_VECTORS, DL,
+                    Concat0.getValueType().getDoubleNumVectorElementsVT(Ctx),
+                    Concat0, Concat2);
+
+    Results.push_back(DAG.getZExtOrTrunc(Concat0, DL, VT));
+    return true;
+  }
+
+  return false;
+}
+
 static void ReplaceAddWithADDP(SDNode *N, SmallVectorImpl<SDValue> &Results,
                                SelectionDAG &DAG,
                                const AArch64Subtarget *Subtarget) {
@@ -25429,6 +25718,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
     return;
   case ISD::ADD:
   case ISD::FADD:
+    if (replaceIntOpWithSVE2LongOp(N, Results, DAG, Subtarget))
+      return;
+
     ReplaceAddWithADDP(N, Results, DAG, Subtarget);
     return;
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index db6e8a00d2fb5e..25f40b553b74f8 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -220,6 +220,9 @@ enum NodeType : unsigned {
   URSHR_I,
   URSHR_I_PRED,
 
+  UADDLB,
+  UADDLT,
+
   // Vector narrowing shift by immediate (bottom)
   RSHRNB_I,
 
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 6972acd985cb9a..8f592cf0a5a3b5 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -3556,6 +3556,22 @@ let Predicates = [HasSVE2orSME, UseExperimentalZeroingPseudos] in {
   defm SQSHLU_ZPZI : sve_int_bin_pred_shift_imm_left_zeroing_bhsd<int_aarch64_sve_sqshlu>;
 } // End HasSVE2orSME, UseExperimentalZeroingPseudos
 
+def SDT_AArch64ArithLong_Unpred : SDTypeProfile<1, 2, [
+  SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<1,2>,
+  SDTCisInt<0>, SDTCisInt<1>,
+  SDTCisOpSmallerThanOp<1, 0>
+]>;
+def AArch64uaddlb_node : SDNode<"AArch64ISD::UADDLB",  SDT_AArch64ArithLong_Unpred>;
+def AArch64uaddlt_node : SDNode<"AArch64ISD::UADDLT",  SDT_AArch64ArithLong_Unpred>;
+
+// TODO: lower the intrinsic to the isd node.
+def AArch64uaddlb : PatFrags<(ops node:$op1, node:$op2),
+                           [(int_aarch64_sve_uaddlb node:$op1, node:$op2),
+                            (AArch64uaddlb_node node:$op1, node:$op2)]>;
+def AArch64uaddlt : PatFrags<(ops node:$op1, node:$op2),
+                           [(int_aarch64_sve_uaddlt node:$op1, node:$op2),
+                            (AArch64uaddlt_node node:$op1, node:$op2)]>;
+
 let Predicates = [HasSVE2orSME] in {
   // SVE2 predicated shifts
   defm SQSHL_ZPmI  : sve_int_bin_pred_shift_imm_left_dup<0b0110, "sqshl",  "SQSHL_ZPZI",  int_aarch64_sve_sqshl>;
@@ -3567,8 +3583,8 @@ let Predicates = [HasSVE2orSME] in {
   // SVE2 integer add/subtract long
   defm SADDLB_ZZZ : sve2_wide_int_arith_long<0b00000, "saddlb", int_aarch64_sve_saddlb>;
   defm SADDLT_ZZZ : sve2_wide_int_arith_long<0b00001, "saddlt", int_aarch64_sve_saddlt>;
-  defm UADDLB_ZZZ : sve2_wide_int_arith_long<0b00010, "uaddlb", int_aarch64_sve_uaddlb>;
-  defm UADDLT_ZZZ : sve2_wide_int_arith_long<0b00011, "uaddlt", int_aarch64_sve_uaddlt>;
+  defm UADDLB_ZZZ : sve2_wide_int_arith_long<0b00010, "uaddlb", AArch64uaddlb>;
+  defm UADDLT_ZZZ : sve2_wide_int_arith_long<0b00011, "uaddlt", AArch64uaddlt>;
   defm SSUBLB_ZZZ : sve2_wide_int_arith_long<0b00100, "ssublb", int_aarch64_sve_ssublb>;
   defm SSUBLT_ZZZ : sve2_wide_int_arith_long<0b00101, "ssublt", int_aarch64_sve_ssublt>;
   defm USUBLB_ZZZ : sve2_wide_int_arith_long<0b00110, "usublb", int_aarch64_sve_usublb>;
diff --git a/llvm/test/CodeGen/AArch64/sve-doublereduct.ll b/llvm/test/CodeGen/AArch64/sve-doublereduct.ll
index 7bc31d44bb6547..6779a43738ce6d 100644
--- a/llvm/test/CodeGen/AArch64/sve-doublereduct.ll
+++ b/llvm/test/CodeGen/AArch64/sve-doublereduct.ll
@@ -126,17 +126,15 @@ define i16 @add_ext_i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
 define i16 @add_ext_v32i16(<vscale x 32 x i8> %a, <vscale x 16 x i8> %b) {
 ; CHECK-LABEL: add_ext_v32i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    uunpklo z3.h, z1.b
-; CHECK-NEXT:    uunpklo z4.h, z0.b
-; CHECK-NEXT:    uunpkhi z1.h, z1.b
-; CHECK-NEXT:    uunpkhi z0.h, z0.b
-; CHECK-NEXT:    uunpkhi z5.h, z2.b
+; CHECK-NEXT:    uaddlt z3.h, z0.b, z1.b
+; CHECK-NEXT:    uaddlb z0.h, z0.b, z1.b
+; CHECK-NEXT:    uunpkhi z1.h, z2.b
 ; CHECK-NEXT:    uunpklo z2.h, z2.b
 ; CHECK-NEXT:    ptrue p0.h
-; CHECK-NEXT:    add z0.h, z0.h, z1.h
-; CHECK-NEXT:    add z1.h, z4.h, z3.h
-; CHECK-NEXT:    add z0.h, z1.h, z0.h
-; CHECK-NEXT:    add z1.h, z2.h, z5.h
+; CHECK-NEXT:    zip2 z4.h, z0.h, z3.h
+; CHECK-NEXT:    zip1 z0.h, z0.h, z3.h
+; CHECK-NEXT:    add z1.h, z2.h, z1.h
+; CHECK-NEXT:    add z0.h, z0.h, z4.h
 ; CHECK-NEXT:    add z0.h, z0.h, z1.h
 ; CHECK-NEXT:    uaddv d0, p0, z0.h
 ; CHECK-NEXT:    fmov x0, d0
diff --git a/llvm/test/CodeGen/AArch64/sve2-uaddl.ll b/llvm/test/CodeGen/AArch64/sve2-uaddl.ll
new file mode 100644
index 00000000000000..caca0db65839d3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve2-uaddl.ll
@@ -0,0 +1,636 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-unknown-linux -mattr=+sve -o - %s | FileCheck --check-prefix=SVE %s
+; RUN: llc -mtriple=aarch64-unknown-linux -mattr=+sve2 -o - %s | FileCheck --check-prefix=SVE2 %s
+
+define <vscale x 16 x i16> @foo_noloadSt_scalable_16x8to16x16(
+; SVE-LABEL: foo_noloadSt_scalable_16x8to16x16:
+; SVE:       // %bb.0:
+; SVE-NEXT:    uunpkhi z2.h, z0.b
+; SVE-NEXT:    uunpklo z0.h, z0.b
+; SVE-NEXT:    uunpkhi z3.h, z1.b
+; SVE-NEXT:    uunpklo z1.h, z1.b
+; SVE-NEXT:    add z0.h, z1.h, z0.h
+; SVE-NEXT:    add z1.h, z3.h, z2.h
+; SVE-NEXT:    ret
+;
+; SVE2-LABEL: foo_noloadSt_scalable_16x8to16x16:
+; SVE2:       // %bb.0:
+; SVE2-NEXT:    uaddlt z2.h, z1.b, z0.b
+; SVE2-NEXT:    uaddlb z1.h, z1.b, z0.b
+; SVE2-NEXT:    zip1 z0.h, z1.h, z2.h
+; SVE2-NEXT:    zip2 z1.h, z1.h, z2.h
+; SVE2-NEXT:    ret
+    <vscale x 16 x i8> %A,
+    <vscale x 16 x i8> %B
+    ) {
+  %1 = zext <vscale x 16 x i8> %A to <vscale x 16 x i16>
+  %2 = zext <vscale x 16 x i8> %B to <vscale x 16 x i16>
+  %add1 = add nuw nsw <vscale x 16 x i16> %2, %1
+  ret <vscale x 16 x i16> %add1
+}
+
+define <vscale x 16 x i32> @foo_noloadSt_scalable_16x8to16x32(
+; SVE-LABEL: foo_noloadSt_scalable_16x8to16x32:
+; SVE:       // %bb.0:
+; SVE-NEXT:    uunpklo z4.h, z1.b
+; SVE-NEXT:    uunpklo z5.h, z0.b
+; SVE-NEXT:    uunpkhi z0.h, z0.b
+; SVE-NEXT:    uunpkhi z1.h, z1.b
+; SVE-NEXT:    uunpklo z26.s, z2.h
+; SVE-NEXT:    uunpkhi z2.s, z2.h
+; SVE-NEXT:    uunpklo z6.s, z5.h
+; SVE-NEXT:    uunpklo z7.s, z4.h
+; SVE-NEXT:    uunpkhi z5.s, z5.h
+; SVE-NEXT:    uunpklo z24.s, z0.h
+; SVE-NEXT:    uunpkhi z0.s, z0.h
+; SVE-NEXT:    uunpkhi z4.s, z4.h
+; SVE-NEXT:    uunpklo z25.s, z1.h
+; SVE-NEXT:    uunpkhi z1.s, z1.h
+; SVE-NEXT:    add z6.s, z7.s, z6.s
+; SVE-NEXT:    uunpkhi z7.s, z3.h
+; SVE-NEXT:    uunpklo z3.s, z3.h
+; SVE-NEXT:    add z27.s, z1.s, z0.s
+; SVE-NEXT:    add z24.s, z25.s, z24.s...
[truncated]

efriedma-quic requested review from topperc and fhahn (April 18, 2024, 22:19)
efriedma-quic (Collaborator) commented:

Given the constraint of doing everything in SelectionDAG, this approach is probably the best you can do.


I think longer-term it will be important to do all the interleaving-related analysis together. We need to worry about the interaction between interleaving add/sub, interleaving load/store, and pure element-wise operations that we want to invoke using deinterleaved operands.

Given that, we might actually want to explicitly represent the deinterleaving operation in IR: teach the vectorizer to generate a deinterleave+zero-extend instead of a double-width zero-extend, or something like that.
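
As a purely illustrative sketch of that suggestion (the intrinsic spelling and types are assumptions, not part of this patch), the vectorizer output could look something like:

; %v is an assumed input. Instead of a single double-width zero-extend of the
; interleaved vector:
;   %wide = zext <vscale x 16 x i8> %v to <vscale x 16 x i16>
; deinterleave first and extend the even/odd halves separately:
%halves = call { <vscale x 8 x i8>, <vscale x 8 x i8> } @llvm.experimental.vector.deinterleave2.nxv16i8(<vscale x 16 x i8> %v)
%even = extractvalue { <vscale x 8 x i8>, <vscale x 8 x i8> } %halves, 0
%odd  = extractvalue { <vscale x 8 x i8>, <vscale x 8 x i8> } %halves, 1
%even.ext = zext <vscale x 8 x i8> %even to <vscale x 8 x i16>
%odd.ext  = zext <vscale x 8 x i8> %odd to <vscale x 8 x i16>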

(See also current work on interleaved load/store in #89018 etc.)

davemgreen (Collaborator) commented:

I have looked at top/bottom instructions in the context of MVE in the past, and ended up writing the MVELaneInterleavingPass. MVE has some features that make lane interleaving more important: it doesn't have a standard sext/zext instruction (or doesn't have a concat, depending on how you look at it), and at times it falls back to storing and reloading from the stack just to do an extend.

The MVELaneInterleavingPass operates at the IR level so that it can work across basic blocks, and it only makes the transform if the whole chain of code can be profitably transformed. That might be different for SVE if there are enough instructions to make local transforms profitable on their own. I think in the long run it would be nice if the vectorizer could properly reason about these operations, so that it could do a better job costing them.
