[llvm] Fix crash when complex deinterleaving operates on an unrolled loop #129735

Merged

Conversation

NickGuy-Arm
Contributor

@NickGuy-Arm NickGuy-Arm commented Mar 4, 2025

When attempting to perform complex deinterleaving on an unrolled loop containing a reduction, the complex deinterleaving pass would fail to accommodate the wider types when accumulating the unrolled paths. Instead of trying to alter the incoming IR to fit expectations, the pass should decide against processing any reduction that results in a non-complex or non-vector value.

@NickGuy-Arm NickGuy-Arm force-pushed the complex-deinterleaving-unroll-crash branch from 12760e6 to 3ed40f4 Compare March 4, 2025 17:10
@NickGuy-Arm NickGuy-Arm marked this pull request as ready for review March 4, 2025 17:11
@llvmbot
Member

llvmbot commented Mar 4, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

When attempting to perform complex deinterleaving on an unrolled loop containing a reduction, the complex deinterleaving pass would fail to accommodate the wider types when accumulating the unrolled paths.


Patch is 22.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129735.diff

2 Files Affected:

  • (modified) llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (+40-6)
  • (added) llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll (+181)
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 92053ed561901..e1e0961874b1b 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -61,6 +61,7 @@
 
 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -274,6 +275,13 @@ class ComplexDeinterleavingGraph {
   /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
   MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
 
+  /// In the case of reductions in unrolled loops, the %OutsideUser from
+  /// ReductionInfo is an add instruction that precedes the reduction.
+  /// UnrollInfo pairs values together if they are both operands of the same
+  /// add. This pairing info is then used to add the resulting complex
+  /// operations together before the final reduction.
+  MapVector<Value *, Value *> UnrollInfo;
+
   /// In the process of detecting a reduction, we consider a pair of
   /// %ReductionOP, which we refer to as real and imag (or vice versa), and
   /// traverse the use-tree to detect complex operations. As this is a reduction
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
   auto *FinalReduction = ReductionInfo[Real].second;
   Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
 
-  auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
+  Value *Other;
+  bool EraseFinalReductionHere = false;
+  if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
+    UnrollInfo[Real] = OperationReplacement;
+    if (!UnrollInfo.contains(Other) || !FinalReduction->hasOneUser())
+      return;
+
+    auto *User = *FinalReduction->user_begin();
+    if (!match(User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
+      return;
+
+    FinalReduction = cast<Instruction>(User);
+    Builder.SetInsertPoint(FinalReduction);
+    OperationReplacement =
+        Builder.CreateAdd(OperationReplacement, UnrollInfo[Other]);
+
+    UnrollInfo.erase(Real);
+    UnrollInfo.erase(Other);
+    EraseFinalReductionHere = true;
+  }
+
+  Value *AddReduce = Builder.CreateAddReduce(OperationReplacement);
   FinalReduction->replaceAllUsesWith(AddReduce);
+  if (EraseFinalReductionHere)
+    FinalReduction->eraseFromParent();
 }
 
 void ComplexDeinterleavingGraph::processReductionOperation(
@@ -2299,7 +2330,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
 }
 
 void ComplexDeinterleavingGraph::replaceNodes() {
-  SmallVector<Instruction *, 16> DeadInstrRoots;
+  SmallSetVector<Instruction *, 16> DeadInstrRoots;
   for (auto *RootInstruction : OrderedRoots) {
     // Check if this potential root went through check process and we can
     // deinterleave it
@@ -2316,20 +2347,23 @@ void ComplexDeinterleavingGraph::replaceNodes() {
       auto *RootImag = cast<Instruction>(RootNode->Imag);
       ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
-      DeadInstrRoots.push_back(RootReal);
-      DeadInstrRoots.push_back(RootImag);
+      DeadInstrRoots.insert(RootReal);
+      DeadInstrRoots.insert(RootImag);
     } else if (RootNode->Operation ==
                ComplexDeinterleavingOperation::ReductionSingle) {
       auto *RootInst = cast<Instruction>(RootNode->Real);
       ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
-      DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
+      DeadInstrRoots.insert(ReductionInfo[RootInst].second);
     } else {
       assert(R && "Unable to find replacement for RootInstruction");
-      DeadInstrRoots.push_back(RootInstruction);
+      DeadInstrRoots.insert(RootInstruction);
       RootInstruction->replaceAllUsesWith(R);
     }
   }
 
+  assert(UnrollInfo.empty() &&
+         "UnrollInfo should be empty after replacing all nodes");
+
   for (auto *I : DeadInstrRoots)
     RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
 }
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll
new file mode 100644
index 0000000000000..e680fd883a1ac
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll
@@ -0,0 +1,181 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve2 -o - | FileCheck %s --check-prefix=CHECK-SVE2
+; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve -o - | FileCheck %s --check-prefix=CHECK-SVE
+; RUN: opt -S --passes=complex-deinterleaving %s -o - | FileCheck %s --check-prefix=CHECK-NOSVE
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a0, <vscale x 32 x i8> %b0, <vscale x 32 x i8> %a1, <vscale x 32 x i8> %b1) {
+; CHECK-SVE2-LABEL: define i32 @cdotp_i8_rot0(
+; CHECK-SVE2-SAME: <vscale x 32 x i8> [[A0:%.*]], <vscale x 32 x i8> [[B0:%.*]], <vscale x 32 x i8> [[A1:%.*]], <vscale x 32 x i8> [[B1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-SVE2-NEXT:  [[ENTRY:.*]]:
+; CHECK-SVE2-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK-SVE2:       [[VECTOR_BODY]]:
+; CHECK-SVE2-NEXT:    [[TMP0:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE2-NEXT:    [[TMP1:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[TMP21:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE2-NEXT:    [[TMP2:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A0]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP3:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B0]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP4:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A0]], i64 16)
+; CHECK-SVE2-NEXT:    [[TMP5:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B0]], i64 16)
+; CHECK-SVE2-NEXT:    [[TMP6:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP0]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP7:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP0]], i64 4)
+; CHECK-SVE2-NEXT:    [[TMP8:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP6]], <vscale x 16 x i8> [[TMP2]], <vscale x 16 x i8> [[TMP3]], i32 0)
+; CHECK-SVE2-NEXT:    [[TMP9:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP7]], <vscale x 16 x i8> [[TMP4]], <vscale x 16 x i8> [[TMP5]], i32 0)
+; CHECK-SVE2-NEXT:    [[TMP10:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP8]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP11]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP10]], <vscale x 4 x i32> [[TMP9]], i64 4)
+; CHECK-SVE2-NEXT:    [[TMP12:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A1]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP13:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B1]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP14:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A1]], i64 16)
+; CHECK-SVE2-NEXT:    [[TMP15:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B1]], i64 16)
+; CHECK-SVE2-NEXT:    [[TMP16:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP1]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP17:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP1]], i64 4)
+; CHECK-SVE2-NEXT:    [[TMP18:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP16]], <vscale x 16 x i8> [[TMP12]], <vscale x 16 x i8> [[TMP13]], i32 0)
+; CHECK-SVE2-NEXT:    [[TMP19:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP17]], <vscale x 16 x i8> [[TMP14]], <vscale x 16 x i8> [[TMP15]], i32 0)
+; CHECK-SVE2-NEXT:    [[TMP20:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP18]], i64 0)
+; CHECK-SVE2-NEXT:    [[TMP21]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP20]], <vscale x 4 x i32> [[TMP19]], i64 4)
+; CHECK-SVE2-NEXT:    br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK-SVE2:       [[MIDDLE_BLOCK]]:
+; CHECK-SVE2-NEXT:    [[TMP22:%.*]] = add <vscale x 8 x i32> [[TMP21]], [[TMP11]]
+; CHECK-SVE2-NEXT:    [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP22]])
+; CHECK-SVE2-NEXT:    ret i32 [[TMP23]]
+;
+; CHECK-SVE-LABEL: define i32 @cdotp_i8_rot0(
+; CHECK-SVE-SAME: <vscale x 32 x i8> [[A0:%.*]], <vscale x 32 x i8> [[B0:%.*]], <vscale x 32 x i8> [[A1:%.*]], <vscale x 32 x i8> [[B1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-SVE-NEXT:  [[ENTRY:.*]]:
+; CHECK-SVE-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK-SVE:       [[VECTOR_BODY]]:
+; CHECK-SVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE33:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[VEC_PHI25:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE34:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT:    [[A0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A0]])
+; CHECK-SVE-NEXT:    [[A0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT:    [[A0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT:    [[A1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A1]])
+; CHECK-SVE-NEXT:    [[A1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT:    [[A1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT:    [[A0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[A1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[B0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B0]])
+; CHECK-SVE-NEXT:    [[B0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT:    [[B0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT:    [[B1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B1]])
+; CHECK-SVE-NEXT:    [[B1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT:    [[B1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT:    [[B0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[B1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[TMP0:%.*]] = mul nsw <vscale x 16 x i32> [[B0_REAL_EXT]], [[A0_REAL_EXT]]
+; CHECK-SVE-NEXT:    [[TMP1:%.*]] = mul nsw <vscale x 16 x i32> [[B1_REAL_EXT]], [[A1_REAL_EXT]]
+; CHECK-SVE-NEXT:    [[A0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[A1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[B0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[B1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT:    [[TMP2:%.*]] = mul nsw <vscale x 16 x i32> [[B0_IMAG_EXT]], [[A0_IMAG_EXT]]
+; CHECK-SVE-NEXT:    [[TMP3:%.*]] = mul nsw <vscale x 16 x i32> [[B1_IMAG_EXT]], [[A1_IMAG_EXT]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP0]])
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE32:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI25]], <vscale x 16 x i32> [[TMP1]])
+; CHECK-SVE-NEXT:    [[TMP4:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP2]]
+; CHECK-SVE-NEXT:    [[TMP5:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP3]]
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE33]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP4]])
+; CHECK-SVE-NEXT:    [[PARTIAL_REDUCE34]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE32]], <vscale x 16 x i32> [[TMP5]])
+; CHECK-SVE-NEXT:    br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK-SVE:       [[MIDDLE_BLOCK]]:
+; CHECK-SVE-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE34]], [[PARTIAL_REDUCE33]]
+; CHECK-SVE-NEXT:    [[TMP6:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-SVE-NEXT:    ret i32 [[TMP6]]
+;
+; CHECK-NOSVE-LABEL: define i32 @cdotp_i8_rot0(
+; CHECK-NOSVE-SAME: <vscale x 32 x i8> [[A0:%.*]], <vscale x 32 x i8> [[B0:%.*]], <vscale x 32 x i8> [[A1:%.*]], <vscale x 32 x i8> [[B1:%.*]]) {
+; CHECK-NOSVE-NEXT:  [[ENTRY:.*]]:
+; CHECK-NOSVE-NEXT:    br label %[[VECTOR_BODY:.*]]
+; CHECK-NOSVE:       [[VECTOR_BODY]]:
+; CHECK-NOSVE-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE33:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NOSVE-NEXT:    [[VEC_PHI25:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE34:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NOSVE-NEXT:    [[A0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A0]])
+; CHECK-NOSVE-NEXT:    [[A0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT:    [[A0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT:    [[A1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A1]])
+; CHECK-NOSVE-NEXT:    [[A1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT:    [[A1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT:    [[A0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[A1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[B0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B0]])
+; CHECK-NOSVE-NEXT:    [[B0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT:    [[B0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT:    [[B1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B1]])
+; CHECK-NOSVE-NEXT:    [[B1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT:    [[B1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT:    [[B0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[B1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[TMP0:%.*]] = mul nsw <vscale x 16 x i32> [[B0_REAL_EXT]], [[A0_REAL_EXT]]
+; CHECK-NOSVE-NEXT:    [[TMP1:%.*]] = mul nsw <vscale x 16 x i32> [[B1_REAL_EXT]], [[A1_REAL_EXT]]
+; CHECK-NOSVE-NEXT:    [[A0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[A1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[B0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[B1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT:    [[TMP2:%.*]] = mul nsw <vscale x 16 x i32> [[B0_IMAG_EXT]], [[A0_IMAG_EXT]]
+; CHECK-NOSVE-NEXT:    [[TMP3:%.*]] = mul nsw <vscale x 16 x i32> [[B1_IMAG_EXT]], [[A1_IMAG_EXT]]
+; CHECK-NOSVE-NEXT:    [[PARTIAL_REDUCE:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP0]])
+; CHECK-NOSVE-NEXT:    [[PARTIAL_REDUCE32:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI25]], <vscale x 16 x i32> [[TMP1]])
+; CHECK-NOSVE-NEXT:    [[TMP4:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP2]]
+; CHECK-NOSVE-NEXT:    [[TMP5:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP3]]
+; CHECK-NOSVE-NEXT:    [[PARTIAL_REDUCE33]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP4]])
+; CHECK-NOSVE-NEXT:    [[PARTIAL_REDUCE34]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE32]], <vscale x 16 x i32> [[TMP5]])
+; CHECK-NOSVE-NEXT:    br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK-NOSVE:       [[MIDDLE_BLOCK]]:
+; CHECK-NOSVE-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE34]], [[PARTIAL_REDUCE33]]
+; CHECK-NOSVE-NEXT:    [[TMP6:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-NOSVE-NEXT:    ret i32 [[TMP6]]
+;
+entry:
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce33, %vector.body ]
+  %vec.phi25 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce34, %vector.body ]
+  %a0.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a0)
+  %a0.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a0.deinterleaved, 0
+  %a0.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a0.deinterleaved, 1
+  %a1.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a1)
+  %a1.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a1.deinterleaved, 0
+  %a1.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a1.deinterleaved, 1
+  %a0.real.ext = sext <vscale x 16 x i8> %a0.real to <vscale x 16 x i32>
+  %a1.real.ext = sext <vscale x 16 x i8> %a1.real to <vscale x 16 x i32>
+  %b0.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b0)
+  %b0.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b0.deinterleaved, 0
+  %b0.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b0.deinterleaved, 1
+  %b1.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b1)
+  %b1.real = extractvalue { <vs...
[truncated]

@MacDue MacDue requested a review from igogo-x86 March 5, 2025 10:52
@NickGuy-Arm
Contributor Author

Ping

Collaborator

@sdesmalen-arm sdesmalen-arm left a comment

While I was able to spot some issues, it's probably better if someone with more experience in this pass has a look at it as well.

Comment on lines +133 to +134
%vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce33, %vector.body ]
%vec.phi25 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce34, %vector.body ]
Collaborator

Can you add a test where the unroll factor is larger than 2, e.g. 4?

auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
Value *Other;
bool EraseFinalReductionHere = false;
if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
Collaborator

When I replace the add in this test with a sub, the pass still crashes, so this is not sufficient.
Does it matter what the operation outside the loop actually is?

I would have expected something like this:

define <vscale x 4 x i32> @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry                                                                           
  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]                                                     
  %a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.v32i8(<vscale x 32 x i8> %a)                               
  %b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.v32i8(<vscale x 32 x i8> %b)                               
  %a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0                                                                    
  %a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1                                                                    
  %b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
  %b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1                                                                    
  %a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>                                                                                     
  %a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>                                                                                     
  %b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
  %b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
  %real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
  %real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
  %imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
  %imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
  %partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  ret <vscale x 4 x i32> %partial.reduce.sub
}

to use cdot instructions as well, but this case also seems to crash. This suggests that the issue is not to do with unrolling, but rather with the user outside the loop being anything other than a reduction?

Contributor Author

@NickGuy-Arm NickGuy-Arm Mar 14, 2025

I think this is a separate issue; I wouldn't expect that snippet to be processed, as handling it with the current implementation would require changing the function return type. The issue this patch aims to fix is the pass trying to replace one operand of an add with a value of a different type.

If this were to instead reinterleave and store the complex result in middle.block, instead of returning it, then I would expect the pass to process it and emit cdot instructions.

When I replace the add in this test with a sub, the pass still crashes, so this is not sufficient.

I'm not sure if the loop vectorizer would ever emit a sub here. Please do correct me if I'm wrong, but I'm not seeing any VECREDUCE_ADD or vecreduce.add equivalent for subtraction, and %bin.rdx in this case is derived from the reduction intrinsic.

Contributor

For regular reductions (without cdot), we needed to analyse and rewrite uses outside of the loop due to Real and Imaginary part extraction. See the cases in complex-deinterleaving-reductions.ll. But for cdot, we don't need to do any of that. Here's a test from complex-deinterleaving-cdot.ll:

define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]
  %a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a)
  %b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b)
  %a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0
  %a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1
  %b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
  %b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1
  %a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>
  %a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>
  %b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
  %b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
  %real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
  %real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
  %imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
  %imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
  %partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %0 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce.sub)
  ret i32 %0
}

It is currently transformed into:

define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %0 = phi <vscale x 8 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
  %1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
  %2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
  %3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
  %4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
  %5 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 0)
  %6 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 4)
  %7 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
  %8 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %6, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
  %9 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> %7, i64 0)
  %10 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> %9, <vscale x 4 x i32> %8, i64 4)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %11 = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %10)
  ret i32 %11
}

But instead, we could ignore everything happening after the final llvm.experimental.vector.partial.reduce.add and simply feed the result of one cdot into the other:

define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
  br label %vector.body

vector.body:                                      ; preds = %vector.body, %entry
  %0 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %8, %vector.body ]
  %1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
  %2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
  %3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
  %4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
  %7 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %0, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
  %8 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %7, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
  br i1 true, label %middle.block, label %vector.body

middle.block:                                     ; preds = %vector.body
  %result = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %8)
  ret i32 %result
}

Collaborator

Thanks @igogo-x86, I agree it makes more sense to feed the result from one cdot into the other, rather than changing the PHI node to be a wider type.

My understanding is that the Complex Deinterleaving pass was created because for certain operations (like reductions) where there is both a PHI+reduction for the imaginary and one PHI+reduction for the real part, it is better to keep the intermediate values interleaved, because the cmla instruction takes interleaved tuples as input and returns interleaved tuples as output. This avoids having to deinterleave values first and it also allows using specialised cmla instructions to do the complex MLA operation. The reduction PHI then contains a vector of <(r, i), (r, i), ..> tuples, which need de-interleaving only when doing the final reduction.
For the case of cdot instructions there is no need for this, because the result vector will always be deinterleaved (the cdot instruction returns either a widened real, or a widened imaginary result).

If that understanding is correct, then I don't really see a need to implement this optimization in the ComplexDeinterleaving pass. This looks more like a DAGCombine of partialreduce(mul(ext(deinterleave(a)), ext(deinterleave(b)))) -> cdot(a, b, #0) (with some variation of this pattern for other rotations). With the new ISD node ISD::PARTIAL_REDUCE_[U|S]MLA added by @JamesChesterman this should be even easier to identify.

Please let me know if I'm missing anything here though.

@NickGuy-Arm
Contributor Author

NickGuy-Arm commented Mar 18, 2025

Instead of trying to accommodate the unexpected types, I've altered this patch to simply not attempt to process reduction loops that would result in a non-complex or non-vector value. In the interest of fixing the crash, this seems like the simplest approach. We can relax the restrictions to accommodate more use cases in the future if necessary.
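
As an illustrative aside, here is a minimal sketch of the kind of guard this describes; the helper name isProcessableReductionValue is hypothetical, and this is not the code that was merged:

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
using namespace llvm;

// Sketch only (an assumption, not the actual patch): decline to handle a
// reduction whose resulting value is not a vector, instead of rewriting the
// surrounding IR. The crash came from splicing a differently-typed value
// into the add that accumulates the unrolled paths; the merged change also
// rejects non-complex values, which this sketch does not model.
static bool isProcessableReductionValue(const Value *V) {
  return isa<VectorType>(V->getType());
}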

@igogo-x86
Contributor

Sounds good! Maybe it's a good idea to add a TODO to fix that, even if you plan to do it right away?

@NickGuy-Arm NickGuy-Arm merged commit 3f4b2f1 into llvm:main Mar 19, 2025
5 of 9 checks passed
@NickGuy-Arm NickGuy-Arm added this to the LLVM 20.X Release milestone Mar 19, 2025
@github-project-automation github-project-automation bot moved this to Needs Triage in LLVM Release Status Mar 19, 2025
@NickGuy-Arm
Contributor Author

/cherry-pick 3f4b2f1

@llvmbot
Member

llvmbot commented Mar 19, 2025

/pull-request #132031

@llvmbot llvmbot moved this from Needs Triage to Done in LLVM Release Status Mar 19, 2025
swift-ci pushed a commit to swiftlang/llvm-project that referenced this pull request Mar 25, 2025
…loop (llvm#129735)

When attempting to perform complex deinterleaving on an unrolled loop
containing a reduction, the complex deinterleaving pass would fail to
accommodate the wider types when accumulating the unrolled paths.
Instead of trying to alter the incoming IR to fit expectations, the pass
should instead decide against processing any reduction that results in a
non-complex or non-vector value.

(cherry picked from commit 3f4b2f1)