
[AArch64][SVE] Add partial reduction SDNodes #117185


Draft · wants to merge 14 commits into main

Conversation

JamesChesterman
Contributor

Add the opcode ISD::PARTIAL_REDUCE_ADD and use it when building SDNodes. When the input and output types allow lowering to wide add or dot product instruction(s), convert the corresponding intrinsic into an SDNode. This will make the legalisation work, to be added in a future patch, easier.

@llvmbot
Member

llvmbot commented Nov 21, 2024

@llvm/pr-subscribers-llvm-selectiondag

@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/117185.diff

8 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/ISDOpcodes.h (+5)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+5)
  • (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+6)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+11-1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h (+1)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp (+3)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+27-29)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 0b6d155b6d161e..7809a1b26dd7cd 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,11 @@ enum NodeType {
   VECREDUCE_UMAX,
   VECREDUCE_UMIN,
 
+  // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+  // Operands: Accumulator, Input
+  // Outputs: Output
+  PARTIAL_REDUCE_ADD,
+
   // The `llvm.experimental.stackmap` intrinsic.
   // Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
   // Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 2e3507386df309..5d0a976fbb66a9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1595,6 +1595,11 @@ class SelectionDAG {
   /// the target's desired shift amount type.
   SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
 
+  /// Get a partial reduction SD node for the DAG. This is done when the input
+  /// and output types can be legalised for wide add(s) or dot product(s)
+  SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
+                                    SDValue Input);
+
   /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
   /// its operands and ReducedTY is the intrinsic's return type.
   SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 677b59e0c8fbeb..6cdf87fd7895c7 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3010,6 +3010,22 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
   }
 };
 
+class PartialReduceAddSDNode : public SDNode {
+public:
+  friend class SelectionDAG;
+
+  PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
+      : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
+
+  const SDValue &getChain() const { return getOperand(0); }
+  const SDValue &getAcc() const { return getOperand(1); }
+  const SDValue &getInput() const { return getOperand(2); }
+
+  static bool classof(const SDNode *N) {
+    return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
+  }
+};
+
 class FPStateAccessSDNode : public MemSDNode {
 public:
   friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3a8ec3c6105bc0..a7a208b6af6f9a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2452,6 +2452,12 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
   return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
 }
 
+SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
+                                                SDValue Acc, SDValue Input) {
+  return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
+                 Input);
+}
+
 SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
                                           SDValue Op2) {
   EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 9d729d448502d8..0480d99767bb75 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6415,6 +6415,16 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
   DAG.setRoot(Histogram);
 }
 
+void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
+                                                unsigned IntrinsicID) {
+  SDLoc dl = getCurSDLoc();
+  SDValue Acc = getValue(I.getOperand(0));
+  SDValue Input = getValue(I.getOperand(1));
+  SDValue Chain = getRoot();
+
+  setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
+}
+
 void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
                                                        unsigned Intrinsic) {
   assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8128,7 +8138,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
   case Intrinsic::experimental_vector_partial_reduce_add: {
 
     if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
-      visitTargetIntrinsic(I, Intrinsic);
+      visitPartialReduceAdd(I, Intrinsic);
       return;
     }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index 3a8dc25e98700e..a9e0c8f1ea10c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,6 +629,7 @@ class SelectionDAGBuilder {
   void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
   void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
   void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
+  void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
   void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
   void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
                    const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 580ff19065557b..8ce03b14bda46c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -567,6 +567,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return "histogram";
 
+  case ISD::PARTIAL_REDUCE_ADD:
+    return "partial_reduce_add";
+
     // Vector Predication
 #define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...)                    \
   case ISD::SDID:                                                              \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7ab3fc06715ec8..6a3fbf3a8b596f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1124,6 +1124,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
   setTargetDAGCombine(
       {ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
 
+  setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+
   setTargetDAGCombine(ISD::FP_EXTEND);
 
   setTargetDAGCombine(ISD::GlobalAddress);
@@ -21718,26 +21720,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
   return SDValue(N, 0);
 }
 
-SDValue tryLowerPartialReductionToDot(SDNode *N,
+SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
                                       const AArch64Subtarget *Subtarget,
                                       SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
-  bool Scalable = N->getValueType(0).isScalableVector();
+  bool Scalable = PR->getValueType(0).isScalableVector();
   if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
   if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(1);
-  auto MulOp = N->getOperand(2);
+  auto NarrowOp = PR->getAcc();
+  auto MulOp = PR->getInput();
   if (MulOp->getOpcode() != ISD::MUL)
     return SDValue();
 
@@ -21755,7 +21752,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   if (A.getValueType() != B.getValueType())
     return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
+  EVT ReducedType = PR->getValueType(0);
   EVT MulSrcType = A.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
@@ -21774,7 +21771,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
-    bool Scalable = N->getValueType(0).isScalableVT();
+    bool Scalable = PR->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
     if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
@@ -21805,22 +21802,17 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
                                           const AArch64Subtarget *Subtarget,
                                           SelectionDAG &DAG) {
 
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
-         getIntrinsicID(N) ==
-             Intrinsic::experimental_vector_partial_reduce_add &&
-         "Expected a partial reduction node");
-
   if (!Subtarget->isSVEorStreamingSVEAvailable())
     return SDValue();
 
-  SDLoc DL(N);
+  SDLoc DL(PR);
 
-  auto Acc = N->getOperand(1);
-  auto ExtInput = N->getOperand(2);
+  auto Acc = PR->getAcc();
+  auto ExtInput = PR->getInput();
 
   EVT AccVT = Acc.getValueType();
   EVT AccElemVT = AccVT.getVectorElementType();
@@ -21847,6 +21839,18 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
   return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
 }
 
+static SDValue
+performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+                               const AArch64Subtarget *Subtarget) {
+  auto *PR = cast<PartialReduceAddSDNode>(N);
+  if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+    return Dot;
+  if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+    return WideAdd;
+  return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
+                                 PR->getInput());
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21855,14 +21859,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
-  case Intrinsic::experimental_vector_partial_reduce_add: {
-    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
-      return Dot;
-    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
-      return WideAdd;
-    return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
-                                   N->getOperand(1), N->getOperand(2));
-  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -26148,6 +26144,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::MSCATTER:
   case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
     return performMaskedGatherScatterCombine(N, DCI, DAG);
+  case ISD::PARTIAL_REDUCE_ADD:
+    return performPartialReduceAddCombine(N, DAG, Subtarget);
   case ISD::FP_EXTEND:
     return performFPExtendCombine(N, DAG, DCI, Subtarget);
   case AArch64ISD::BRCOND:


github-actions bot commented Nov 28, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@paulwalker-arm
Collaborator

I've not had a chance to look at the PR yet, but from a high level I think having an ISD node that matches the definition of the intrinsic this closely is going to make legalisation and selection harder than necessary. I was expecting the ISD node(s) to be signed operations, because then you open up the possibility of the operands having different element types, which will make it much easier to legalise efficiently as well as providing a route to potentially isel directly.

Specifically, I'm suggesting you follow a similar idiom to that used by other nodes where implicit extension is beneficial (e.g. MULH, ABD) and implement PARTIAL_VECREDUCE_SADD and PARTIAL_VECREDUCE_UADD.

Comment on lines 21829 to 21834
if (ISD::isExtOpcode(InputOpcode)) {
  Input = Input.getOperand(0);
  if (InputOpcode == ISD::SIGN_EXTEND)
    Opcode = ISD::PARTIAL_REDUCE_SADD;
}

Member

Correct me if I'm wrong, but it looks like no ISD::PARTIAL_REDUCE_SADD node is created anymore?

I think what Paul meant was that there would be a DAG combine going from ISD::PARTIAL_REDUCE_UADD to ISD::PARTIAL_REDUCE_SADD (by checking the inputs), and then later a lowering for both ISD::PARTIAL_REDUCE_UADD and ISD::PARTIAL_REDUCE_SADD.

@JamesChesterman (Contributor, Author) commented Dec 9, 2024

Yes, currently there is no node generated for ISD::PARTIAL_REDUCE_SADD. I'm going to change it so the DAG combine just removes any extends and changes the opcode to ISD::PARTIAL_REDUCE_SADD if need be. On the dot product side, I think I will keep the lowering for usdot in performPartialReduceAddCombine, because removing the extends would lose the information about whether each operand is signed, while not removing them would cause splitting to happen too many times during legalisation.

@JamesChesterman (Contributor, Author) commented Dec 11, 2024

I've now made it so the DAG-combine function decides whether to keep the node opcode as ISD::PARTIAL_REDUCE_UADD or to convert it to ISD::PARTIAL_REDUCE_SADD. However, the DAG-combine function also includes the lowering for usdot IR patterns (for the reasons described in the comment above). The DAG-combine also needed to include lowering for (nx)v16i8 to (nx)v4i64, because that node would otherwise be sent through type legalisation, even though this case can simply be done with an i32 dot product followed by an extension. There is now also a lowering function, called from LowerOperation, which covers the general cases for lowering to wide add / dot product instructions.

@paulwalker-arm
Collaborator

Perhaps a left-field suggestion, but looking at the use cases you want to handle I'm wondering if my initial suggestion was too limiting and we'd instead be better off with something more powerful like PARTIAL_REDUCE_SMLA/PARTIAL_REDUCE_UMLA. That way you'll have an easier time representing the expected use cases of the intrinsic whilst still being able to represent the most general form by using a unit vector for one of the multiplicands.

This will re-raise an earlier question of yours as to whether we need PARTIAL_REDUCE_USMLA. To that I would personally hold off and maintain the current DAG combine because it feels somewhat target specific, but at the same time I wouldn't object to it if the consensus is in favour of it.

@JamesChesterman marked this pull request as ready for review January 10, 2025 10:10
@sdesmalen-arm (Collaborator) left a comment

There are no changes to any of the tests; is that expected?

@JamesChesterman
Contributor Author

There are no changes to any of the tests; is that expected?

Yes, this is expected. This patch changes the design of the ISD node to make it easier to use now, as well as easier to use in future scenarios.

Commit messages:

  • Add the opcode 'ISD::PARTIAL_REDUCE_ADD' and use it when making SDNodes. When the inputs and outputs have types that can allow for lowering to wide add or dot product instruction(s), then convert the corresponding intrinsic to an SDNode. This will allow legalisation, which will be added in a future patch, to be done more easily.
  • […] as well as changing how the intrinsic is transformed into the SD node.
  • Separate lowering code from all being in the DAG-combine function. Now the DAG-combine decides whether the node should be the signed or unsigned version of partial reduce add. Then there is a function in LowerOperation that does the actual lowering to wide adds or dot products if it is able to.
  • Add condition in wide add combine to not allow fixed length vectors.
  • ISD::PARTIAL_REDUCE_S/UMLA: this makes the lowering function easier as you do not need to worry about whether the MUL is lowered or not. Instead its operands are taken from it. If there is no MUL instruction and just one operand, the other operand is a vector of ones (for value types eligible for wide add lowering).
  • Only do it if Input2 is a splat vector of constant 1s. Still create the MUL in the DAG combine for the wide add pattern. This is because it is pruned if an operand is constant 1s, or changed to a shift instruction if an operand is a power of 2. This would not happen if the MUL was made in LowerPARTIAL_REDUCE_MLA. […] operation as Input1.
  • Also change the LangRef in ISDOpcodes.h for PARTIAL_REDUCE_MLA nodes to set restrictions on what can be used for its inputs.
  • Rename functions to accord with PARTIAL_REDUCE_MLA rather than PARTIAL_REDUCE_ADD.
Comment on lines +22036 to +22037
if ((Op2Opcode != ISD::SPLAT_VECTOR && Op2Opcode != ISD::BUILD_VECTOR) ||
    !isOneConstant(Op2->getOperand(0)))
Collaborator

Please use ISD::isConstantSplatVector instead, and then check that the APInt is one using isOne().

Comment on lines +22182 to +22186
if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL)) {
  Op0 = MLA->getOperand(0);
  Op1 = MLA->getOperand(1);
  Op2 = MLA->getOperand(2);
}
Collaborator

Suggested change:
- if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL)) {
-   Op0 = MLA->getOperand(0);
-   Op1 = MLA->getOperand(1);
-   Op2 = MLA->getOperand(2);
- }
+ if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL))
+   return MLA;

This should return MLA here. It will be revisited and if need be further optimised by the cases below.

return WideAdd;
// N->getOperand needs calling again because the Op variables may have been
// changed by the functions above
return DAG.expandPartialReduceMLA(DL, N->getOperand(0), N->getOperand(1),
Collaborator

The default should not be to expand the reduction here. This is a target-specific DAGCombine that tries to optimise the DAG for AArch64-specific use-cases. Expansion should only happen if the optimized DAG cannot be lowered or legalized.

v4i32 PARTIAL_REDUCE_UMLA(v4i32 Acc, v8i32 SEXT(v8i16 X), v8i32 SEXT(v8i16 Y))
->
v4i32 PARTIAL_REDUCE_SMLA(v4i32 Acc, v8i16 X, v8i16 Y)

In the case where the extends cannot be recognised, e.g.

v4i32 PARTIAL_REDUCE_UMLA(v4i32 Acc, v8i32 X, v8i32 Y)

Then this would require type legalisation (splitting, to break up the v8i32 -> 2 x v4i32), which for now could fall back to expandPartialReduce.

The cases that can't be represented with these new nodes are the USDOT instructions. Those you could lower in this function to a custom AArch64ISD node.

Comment on lines +22021 to +22033
SDValue ExtMulOpLHS = Op1->getOperand(0);
SDValue ExtMulOpRHS = Op1->getOperand(1);
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
    !ISD::isExtOpcode(ExtMulOpRHSOpcode))
  return SDValue();

SDLoc DL(N);
SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
EVT MulOpLHSVT = MulOpLHS.getValueType();
if (MulOpLHSVT != MulOpRHS.getValueType())
  return SDValue();
Collaborator

With the way you've currently written this function, my understanding is that this code can be removed, because it does not try to fold any of the extends into the PARTIAL_REDUCE operation; i.e. all it is doing is:

PARTIAL_REDUCE_UMLA(acc, MUL(lhs, rhs), SPLAT(1))
->
PARTIAL_REDUCE_UMLA(acc, lhs, rhs)

@@ -1451,6 +1451,21 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,

// Partial Reduction nodes. These represent multiply-add instructions because
// Input1 and Input2 are multiplied together first. This result is then
Collaborator

I thought the behaviour of this node is that Input1 and Input2 are first sign- (for SMLA) or zero- (for UMLA) extended to be the same element type as Acc.

nit: Input1 and Input2 are not defined yet. What about starting with a description of the node with operands, e.g. PARTIAL_REDUCE_*MLA(Acc, Input1, Input2) ?

if (Op0ElemVT != Op1ElemVT || Op1.getOpcode() == AArch64ISD::USDOT)
  return SDValue(N, 0);

if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL)) {
Collaborator

This seems like a generically useful combine, so it can be moved to the generic combine code in DAGCombiner.cpp

@JamesChesterman marked this pull request as draft January 29, 2025 17:01
@JamesChesterman
Contributor Author

This work is being split into separate Pull Requests and is being done differently.
See here for first PR:
#125207

Labels: backend:AArch64, llvm:SelectionDAG