[llvm] Fix crash when complex deinterleaving operates on an unrolled loop #129735
Conversation
Force-pushed 12760e6 to 3ed40f4
@llvm/pr-subscribers-backend-aarch64

Author: Nicholas Guy (NickGuy-Arm)

Changes

When attempting to perform complex deinterleaving on an unrolled loop containing a reduction, the complex deinterleaving pass would fail to accommodate the wider types when accumulating the unrolled paths.

Patch is 22.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129735.diff

2 Files Affected:
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 92053ed561901..e1e0961874b1b 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -61,6 +61,7 @@
#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -274,6 +275,13 @@ class ComplexDeinterleavingGraph {
/// `llvm.vector.reduce.fadd` when unroll factor isn't one.
MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
+ /// In the case of reductions in unrolled loops, the %OutsideUser from
+ /// ReductionInfo is an add instruction that precedes the reduction.
+ /// UnrollInfo pairs values together if they are both operands of the same
+ /// add. This pairing info is then used to add the resulting complex
+ /// operations together before the final reduction.
+ MapVector<Value *, Value *> UnrollInfo;
+
/// In the process of detecting a reduction, we consider a pair of
/// %ReductionOP, which we refer to as real and imag (or vice versa), and
/// traverse the use-tree to detect complex operations. As this is a reduction
@@ -2253,8 +2261,31 @@ void ComplexDeinterleavingGraph::processReductionSingle(
auto *FinalReduction = ReductionInfo[Real].second;
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
- auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
+ Value *Other;
+ bool EraseFinalReductionHere = false;
+ if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) {
+ UnrollInfo[Real] = OperationReplacement;
+ if (!UnrollInfo.contains(Other) || !FinalReduction->hasOneUser())
+ return;
+
+ auto *User = *FinalReduction->user_begin();
+ if (!match(User, m_Intrinsic<Intrinsic::vector_reduce_add>()))
+ return;
+
+ FinalReduction = cast<Instruction>(User);
+ Builder.SetInsertPoint(FinalReduction);
+ OperationReplacement =
+ Builder.CreateAdd(OperationReplacement, UnrollInfo[Other]);
+
+ UnrollInfo.erase(Real);
+ UnrollInfo.erase(Other);
+ EraseFinalReductionHere = true;
+ }
+
+ Value *AddReduce = Builder.CreateAddReduce(OperationReplacement);
FinalReduction->replaceAllUsesWith(AddReduce);
+ if (EraseFinalReductionHere)
+ FinalReduction->eraseFromParent();
}
void ComplexDeinterleavingGraph::processReductionOperation(
@@ -2299,7 +2330,7 @@ void ComplexDeinterleavingGraph::processReductionOperation(
}
void ComplexDeinterleavingGraph::replaceNodes() {
- SmallVector<Instruction *, 16> DeadInstrRoots;
+ SmallSetVector<Instruction *, 16> DeadInstrRoots;
for (auto *RootInstruction : OrderedRoots) {
// Check if this potential root went through check process and we can
// deinterleave it
@@ -2316,20 +2347,23 @@ void ComplexDeinterleavingGraph::replaceNodes() {
auto *RootImag = cast<Instruction>(RootNode->Imag);
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
- DeadInstrRoots.push_back(RootReal);
- DeadInstrRoots.push_back(RootImag);
+ DeadInstrRoots.insert(RootReal);
+ DeadInstrRoots.insert(RootImag);
} else if (RootNode->Operation ==
ComplexDeinterleavingOperation::ReductionSingle) {
auto *RootInst = cast<Instruction>(RootNode->Real);
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
- DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
+ DeadInstrRoots.insert(ReductionInfo[RootInst].second);
} else {
assert(R && "Unable to find replacement for RootInstruction");
- DeadInstrRoots.push_back(RootInstruction);
+ DeadInstrRoots.insert(RootInstruction);
RootInstruction->replaceAllUsesWith(R);
}
}
+ assert(UnrollInfo.empty() &&
+ "UnrollInfo should be empty after replacing all nodes");
+
for (auto *I : DeadInstrRoots)
RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
}
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll
new file mode 100644
index 0000000000000..e680fd883a1ac
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-unrolled-cdot.ll
@@ -0,0 +1,181 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve2 -o - | FileCheck %s --check-prefix=CHECK-SVE2
+; RUN: opt -S --passes=complex-deinterleaving %s --mattr=+sve -o - | FileCheck %s --check-prefix=CHECK-SVE
+; RUN: opt -S --passes=complex-deinterleaving %s -o - | FileCheck %s --check-prefix=CHECK-NOSVE
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a0, <vscale x 32 x i8> %b0, <vscale x 32 x i8> %a1, <vscale x 32 x i8> %b1) {
+; CHECK-SVE2-LABEL: define i32 @cdotp_i8_rot0(
+; CHECK-SVE2-SAME: <vscale x 32 x i8> [[A0:%.*]], <vscale x 32 x i8> [[B0:%.*]], <vscale x 32 x i8> [[A1:%.*]], <vscale x 32 x i8> [[B1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-SVE2-NEXT: [[ENTRY:.*]]:
+; CHECK-SVE2-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK-SVE2: [[VECTOR_BODY]]:
+; CHECK-SVE2-NEXT: [[TMP0:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[TMP11:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE2-NEXT: [[TMP1:%.*]] = phi <vscale x 8 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[TMP21:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE2-NEXT: [[TMP2:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A0]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP3:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B0]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP4:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A0]], i64 16)
+; CHECK-SVE2-NEXT: [[TMP5:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B0]], i64 16)
+; CHECK-SVE2-NEXT: [[TMP6:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP0]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP7:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP0]], i64 4)
+; CHECK-SVE2-NEXT: [[TMP8:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP6]], <vscale x 16 x i8> [[TMP2]], <vscale x 16 x i8> [[TMP3]], i32 0)
+; CHECK-SVE2-NEXT: [[TMP9:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP7]], <vscale x 16 x i8> [[TMP4]], <vscale x 16 x i8> [[TMP5]], i32 0)
+; CHECK-SVE2-NEXT: [[TMP10:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP8]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP11]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP10]], <vscale x 4 x i32> [[TMP9]], i64 4)
+; CHECK-SVE2-NEXT: [[TMP12:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A1]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP13:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B1]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP14:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[A1]], i64 16)
+; CHECK-SVE2-NEXT: [[TMP15:%.*]] = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> [[B1]], i64 16)
+; CHECK-SVE2-NEXT: [[TMP16:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP1]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP17:%.*]] = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> [[TMP1]], i64 4)
+; CHECK-SVE2-NEXT: [[TMP18:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP16]], <vscale x 16 x i8> [[TMP12]], <vscale x 16 x i8> [[TMP13]], i32 0)
+; CHECK-SVE2-NEXT: [[TMP19:%.*]] = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> [[TMP17]], <vscale x 16 x i8> [[TMP14]], <vscale x 16 x i8> [[TMP15]], i32 0)
+; CHECK-SVE2-NEXT: [[TMP20:%.*]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> [[TMP18]], i64 0)
+; CHECK-SVE2-NEXT: [[TMP21]] = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> [[TMP20]], <vscale x 4 x i32> [[TMP19]], i64 4)
+; CHECK-SVE2-NEXT: br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK-SVE2: [[MIDDLE_BLOCK]]:
+; CHECK-SVE2-NEXT: [[TMP22:%.*]] = add <vscale x 8 x i32> [[TMP21]], [[TMP11]]
+; CHECK-SVE2-NEXT: [[TMP23:%.*]] = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> [[TMP22]])
+; CHECK-SVE2-NEXT: ret i32 [[TMP23]]
+;
+; CHECK-SVE-LABEL: define i32 @cdotp_i8_rot0(
+; CHECK-SVE-SAME: <vscale x 32 x i8> [[A0:%.*]], <vscale x 32 x i8> [[B0:%.*]], <vscale x 32 x i8> [[A1:%.*]], <vscale x 32 x i8> [[B1:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-SVE-NEXT: [[ENTRY:.*]]:
+; CHECK-SVE-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK-SVE: [[VECTOR_BODY]]:
+; CHECK-SVE-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE33:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT: [[VEC_PHI25:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE34:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-SVE-NEXT: [[A0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A0]])
+; CHECK-SVE-NEXT: [[A0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT: [[A0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT: [[A1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A1]])
+; CHECK-SVE-NEXT: [[A1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT: [[A1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT: [[A0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[A1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[B0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B0]])
+; CHECK-SVE-NEXT: [[B0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT: [[B0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT: [[B1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B1]])
+; CHECK-SVE-NEXT: [[B1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 0
+; CHECK-SVE-NEXT: [[B1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 1
+; CHECK-SVE-NEXT: [[B0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[B1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_REAL]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[TMP0:%.*]] = mul nsw <vscale x 16 x i32> [[B0_REAL_EXT]], [[A0_REAL_EXT]]
+; CHECK-SVE-NEXT: [[TMP1:%.*]] = mul nsw <vscale x 16 x i32> [[B1_REAL_EXT]], [[A1_REAL_EXT]]
+; CHECK-SVE-NEXT: [[A0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[A1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[B0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[B1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_IMAG]] to <vscale x 16 x i32>
+; CHECK-SVE-NEXT: [[TMP2:%.*]] = mul nsw <vscale x 16 x i32> [[B0_IMAG_EXT]], [[A0_IMAG_EXT]]
+; CHECK-SVE-NEXT: [[TMP3:%.*]] = mul nsw <vscale x 16 x i32> [[B1_IMAG_EXT]], [[A1_IMAG_EXT]]
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP0]])
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE32:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI25]], <vscale x 16 x i32> [[TMP1]])
+; CHECK-SVE-NEXT: [[TMP4:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP2]]
+; CHECK-SVE-NEXT: [[TMP5:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP3]]
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE33]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP4]])
+; CHECK-SVE-NEXT: [[PARTIAL_REDUCE34]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE32]], <vscale x 16 x i32> [[TMP5]])
+; CHECK-SVE-NEXT: br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK-SVE: [[MIDDLE_BLOCK]]:
+; CHECK-SVE-NEXT: [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE34]], [[PARTIAL_REDUCE33]]
+; CHECK-SVE-NEXT: [[TMP6:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-SVE-NEXT: ret i32 [[TMP6]]
+;
+; CHECK-NOSVE-LABEL: define i32 @cdotp_i8_rot0(
+; CHECK-NOSVE-SAME: <vscale x 32 x i8> [[A0:%.*]], <vscale x 32 x i8> [[B0:%.*]], <vscale x 32 x i8> [[A1:%.*]], <vscale x 32 x i8> [[B1:%.*]]) {
+; CHECK-NOSVE-NEXT: [[ENTRY:.*]]:
+; CHECK-NOSVE-NEXT: br label %[[VECTOR_BODY:.*]]
+; CHECK-NOSVE: [[VECTOR_BODY]]:
+; CHECK-NOSVE-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE33:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NOSVE-NEXT: [[VEC_PHI25:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, %[[ENTRY]] ], [ [[PARTIAL_REDUCE34:%.*]], %[[VECTOR_BODY]] ]
+; CHECK-NOSVE-NEXT: [[A0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A0]])
+; CHECK-NOSVE-NEXT: [[A0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT: [[A0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A0_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT: [[A1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[A1]])
+; CHECK-NOSVE-NEXT: [[A1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT: [[A1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[A1_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT: [[A0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[A1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[B0_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B0]])
+; CHECK-NOSVE-NEXT: [[B0_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT: [[B0_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B0_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT: [[B1_DEINTERLEAVED:%.*]] = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> [[B1]])
+; CHECK-NOSVE-NEXT: [[B1_REAL:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 0
+; CHECK-NOSVE-NEXT: [[B1_IMAG:%.*]] = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } [[B1_DEINTERLEAVED]], 1
+; CHECK-NOSVE-NEXT: [[B0_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[B1_REAL_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_REAL]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[TMP0:%.*]] = mul nsw <vscale x 16 x i32> [[B0_REAL_EXT]], [[A0_REAL_EXT]]
+; CHECK-NOSVE-NEXT: [[TMP1:%.*]] = mul nsw <vscale x 16 x i32> [[B1_REAL_EXT]], [[A1_REAL_EXT]]
+; CHECK-NOSVE-NEXT: [[A0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A0_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[A1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[A1_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[B0_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B0_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[B1_IMAG_EXT:%.*]] = sext <vscale x 16 x i8> [[B1_IMAG]] to <vscale x 16 x i32>
+; CHECK-NOSVE-NEXT: [[TMP2:%.*]] = mul nsw <vscale x 16 x i32> [[B0_IMAG_EXT]], [[A0_IMAG_EXT]]
+; CHECK-NOSVE-NEXT: [[TMP3:%.*]] = mul nsw <vscale x 16 x i32> [[B1_IMAG_EXT]], [[A1_IMAG_EXT]]
+; CHECK-NOSVE-NEXT: [[PARTIAL_REDUCE:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI]], <vscale x 16 x i32> [[TMP0]])
+; CHECK-NOSVE-NEXT: [[PARTIAL_REDUCE32:%.*]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[VEC_PHI25]], <vscale x 16 x i32> [[TMP1]])
+; CHECK-NOSVE-NEXT: [[TMP4:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP2]]
+; CHECK-NOSVE-NEXT: [[TMP5:%.*]] = sub nsw <vscale x 16 x i32> zeroinitializer, [[TMP3]]
+; CHECK-NOSVE-NEXT: [[PARTIAL_REDUCE33]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE]], <vscale x 16 x i32> [[TMP4]])
+; CHECK-NOSVE-NEXT: [[PARTIAL_REDUCE34]] = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> [[PARTIAL_REDUCE32]], <vscale x 16 x i32> [[TMP5]])
+; CHECK-NOSVE-NEXT: br i1 true, label %[[MIDDLE_BLOCK:.*]], label %[[VECTOR_BODY]]
+; CHECK-NOSVE: [[MIDDLE_BLOCK]]:
+; CHECK-NOSVE-NEXT: [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[PARTIAL_REDUCE34]], [[PARTIAL_REDUCE33]]
+; CHECK-NOSVE-NEXT: [[TMP6:%.*]] = tail call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-NOSVE-NEXT: ret i32 [[TMP6]]
+;
+entry:
+ br label %vector.body
+
+vector.body: ; preds = %vector.body, %entry
+ %vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce33, %vector.body ]
+ %vec.phi25 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce34, %vector.body ]
+ %a0.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a0)
+ %a0.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a0.deinterleaved, 0
+ %a0.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a0.deinterleaved, 1
+ %a1.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a1)
+ %a1.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a1.deinterleaved, 0
+ %a1.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a1.deinterleaved, 1
+ %a0.real.ext = sext <vscale x 16 x i8> %a0.real to <vscale x 16 x i32>
+ %a1.real.ext = sext <vscale x 16 x i8> %a1.real to <vscale x 16 x i32>
+ %b0.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b0)
+ %b0.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b0.deinterleaved, 0
+ %b0.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b0.deinterleaved, 1
+ %b1.deinterleaved = tail call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b1)
+ %b1.real = extractvalue { <vs...
[truncated]
Ping
While I was able to spot some issues, it's probably better if someone with more experience in this pass has a look at it as well.
%vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce33, %vector.body ]
%vec.phi25 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce34, %vector.body ]
Can you add a test where the unroll factor is larger than 2, e.g. 4?
auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
Value *Other;
bool EraseFinalReductionHere = false;
if (match(FinalReduction, m_c_Add(m_Specific(Real), m_Value(Other)))) { |
When I replace the add in this test by a sub, the pass still crashes, so this is not sufficient.
Does it matter what the operation (the one outside the loop) actually is?
I would have expected something like this:
define <vscale x 4 x i32> @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]
  %a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a)
  %b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b)
%a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0
%a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1
%b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
%b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1
%a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>
%a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>
%b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
%b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
%real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
%real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
%imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
%imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
%partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
ret <vscale x 4 x i32> %partial.reduce.sub
}
to use cdot instructions as well, but this case also seems to crash. This suggests that the issue is not to do with unrolling, but rather with the user outside the loop being anything other than a reduction?
I think this is a separate issue; I wouldn't expect that snippet to be processed, as the current implementation would require changing the function return type. Whereas the issue this patch is aimed at fixing is when it tries to replace one operand of an add with a value of a different type.
If this were to instead reinterleave and store the complex result in middle.block, instead of returning it, then I would expect the pass to process it and emit cdot instructions.
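For illustration, a minimal sketch of that variant, assuming a hypothetical extra ptr %dst argument (not part of the original snippet; with a single accumulator no explicit reinterleave is needed):

middle.block:                                     ; preds = %vector.body
  store <vscale x 4 x i32> %partial.reduce.sub, ptr %dst
  ret void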
When I replace the add in this test by a sub, the pass still crashes, so this is not sufficient.
I'm not sure if the loop vectorizer would ever emit a sub here. Please do correct me if I'm wrong, but I'm not seeing any VECREDUCE_ADD or vecreduce.add equivalent for subtraction, and the instruction of %bin.rdx in this case is derived from the reduction intrinsic.
For regular reductions (without cdot), we needed to analyse and rewrite uses outside of the loop due to real and imaginary part extraction. See cases in complex-deinterleaving-reductions.ll. But for cdot, we don't need to do any of that. Here's a test from complex-deinterleaving-cdot.ll:
define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%vec.phi = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %partial.reduce.sub, %vector.body ]
%a.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %a)
%b.deinterleaved = call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.vector.deinterleave2.nxv32i8(<vscale x 32 x i8> %b)
%a.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 0
%a.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %a.deinterleaved, 1
%b.real = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 0
%b.imag = extractvalue { <vscale x 16 x i8>, <vscale x 16 x i8> } %b.deinterleaved, 1
%a.real.ext = sext <vscale x 16 x i8> %a.real to <vscale x 16 x i32>
%a.imag.ext = sext <vscale x 16 x i8> %a.imag to <vscale x 16 x i32>
%b.real.ext = sext <vscale x 16 x i8> %b.real to <vscale x 16 x i32>
%b.imag.ext = sext <vscale x 16 x i8> %b.imag to <vscale x 16 x i32>
%real.mul = mul <vscale x 16 x i32> %b.real.ext, %a.real.ext
%real.mul.reduced = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %vec.phi, <vscale x 16 x i32> %real.mul)
%imag.mul = mul <vscale x 16 x i32> %b.imag.ext, %a.imag.ext
%imag.mul.neg = sub <vscale x 16 x i32> zeroinitializer, %imag.mul
%partial.reduce.sub = call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %real.mul.reduced, <vscale x 16 x i32> %imag.mul.neg)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
%0 = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %partial.reduce.sub)
ret i32 %0
}
It is currently transformed into:
define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%0 = phi <vscale x 8 x i32> [ zeroinitializer, %entry ], [ %10, %vector.body ]
%1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
%2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
%3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
%4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
%5 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 0)
%6 = call <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32> %0, i64 4)
%7 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
%8 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %6, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
%9 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> poison, <vscale x 4 x i32> %7, i64 0)
%10 = call <vscale x 8 x i32> @llvm.vector.insert.nxv8i32.nxv4i32(<vscale x 8 x i32> %9, <vscale x 4 x i32> %8, i64 4)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
%11 = call i32 @llvm.vector.reduce.add.nxv8i32(<vscale x 8 x i32> %10)
ret i32 %11
}
But instead, we could ignore everything happening after the final llvm.experimental.vector.partial.reduce.add and just chain one cdot into another:
define i32 @cdotp_i8_rot0(<vscale x 32 x i8> %a, <vscale x 32 x i8> %b) #0 {
entry:
br label %vector.body
vector.body: ; preds = %vector.body, %entry
%0 = phi <vscale x 4 x i32> [ zeroinitializer, %entry ], [ %6, %vector.body ]
%1 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 0)
%2 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 0)
%3 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %a, i64 16)
%4 = call <vscale x 16 x i8> @llvm.vector.extract.nxv16i8.nxv32i8(<vscale x 32 x i8> %b, i64 16)
%5 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %0, <vscale x 16 x i8> %1, <vscale x 16 x i8> %2, i32 0)
%6 = call <vscale x 4 x i32> @llvm.aarch64.sve.cdot.nxv4i32(<vscale x 4 x i32> %5, <vscale x 16 x i8> %3, <vscale x 16 x i8> %4, i32 0)
br i1 true, label %middle.block, label %vector.body
middle.block: ; preds = %vector.body
%result = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> %6)
ret i32 %result
}
Thanks @igogo-x86, I agree it makes more sense to feed the result from one cdot into the other, rather than changing the PHI node to be a wider type.
My understanding is that the Complex Deinterleaving pass was created because for certain operations (like reductions), where there is one PHI+reduction for the imaginary part and one PHI+reduction for the real part, it is better to keep the intermediate values interleaved: the cmla instruction takes interleaved tuples as input and returns interleaved tuples as output. This avoids having to deinterleave values first, and it also allows using specialised cmla instructions to do the complex MLA operation. The reduction PHI then contains a vector of <(r, i), (r, i), ..> tuples, which need de-interleaving only when doing the final reduction.
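Conceptually, the accumulator layout looks like this (a sketch, not taken from the patch):

; cmla-style reduction PHI holds interleaved (r, i) tuples:
;   %vec.phi = <r0, i0, r1, i1, ...>
; deinterleave2(%vec.phi) -> (%reals, %imags) happens only once,
; when computing the final scalar reduction outside the loop.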
For the case of cdot instructions there is no need for this, because the result vector will always be deinterleaved (the cdot instruction returns either a widened real or a widened imaginary result).
If that understanding is correct, then I don't really see a need to implement this optimization in the ComplexDeinterleaving pass. This looks more like a DAGCombine of partialreduce(mul(ext(deinterleave(a)), ext(deinterleave(b)))) -> cdot(a, b, #0) (with some variation of this pattern for other rotations). With the new ISD node ISD::PARTIAL_REDUCE_[U|S]MLA added by @JamesChesterman, this should be even easier to identify.
Please let me know if I'm missing anything here though.
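For reference, the rotation-0 pattern such a combine would look for, sketched here in IR-level pseudocode (illustrative only; an actual DAGCombine would match the corresponding ISD nodes):

; %a, %b: interleaved complex inputs; %acc: the partial-sum accumulator
;   %rr  = mul (sext %a.real), (sext %b.real)
;   %ii  = mul (sext %a.imag), (sext %b.imag)
;   %t   = partial.reduce.add(%acc, %rr)
;   %res = partial.reduce.add(%t, (sub 0, %ii))
; folds to:
;   %res = call @llvm.aarch64.sve.cdot(%acc, %a, %b, i32 0)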
Instead of trying to accommodate the unexpected types, I've altered this patch to simply not attempt to process reduction loops that would result in a non-complex or non-vector value. In the interest of fixing the crash, this seems like the simplest approach. We can lessen the restrictions later to accommodate more use cases if necessary.
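A minimal sketch of the kind of guard this implies (hypothetical; not the exact committed change), placed wherever the pass decides whether a reduction is worth processing:

// Hypothetical bail-out: only process reductions whose accumulated value
// is a vector, since a non-vector (non-complex) value cannot be
// deinterleaved and would later trigger the type mismatch described above.
if (!isa<VectorType>(Real->getType()))
  return;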
Sounds good! Maybe it's a good idea to add a TODO to fix that, even if you plan to do it right away?
/cherry-pick 3f4b2f1
/pull-request #132031
[llvm] Fix crash when complex deinterleaving operates on an unrolled loop (llvm#129735)

When attempting to perform complex deinterleaving on an unrolled loop containing a reduction, the complex deinterleaving pass would fail to accommodate the wider types when accumulating the unrolled paths. Instead of trying to alter the incoming IR to fit expectations, the pass should instead decide against processing any reduction that results in a non-complex or non-vector value.

(cherry picked from commit 3f4b2f1)