Skip to content

Commit 9644e2a

Browse files
committed
[Machine-Combiner] Add pattern to rewrite chains of MLA instructions into a tree for increased ILP
1 parent 7af7c59 commit 9644e2a

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6709,6 +6709,7 @@ bool AArch64InstrInfo::isAccumulationOpcode(unsigned Opcode) const {
67096709
case AArch64::SABAv4i32:
67106710
case AArch64::SABAv8i16:
67116711
case AArch64::SABAv8i8:
6712+
case AArch64::MLAv8i8:
67126713
return true;
67136714
}
67146715

@@ -6720,6 +6721,8 @@ std::optional<unsigned> AArch64InstrInfo::getAccumulationStartOpcode(
67206721
switch (AccumulationOpcode) {
67216722
default:
67226723
llvm_unreachable("Unknown accumulator opcode");
6724+
case AArch64::MLAv8i8:
6725+
return AArch64::MULv8i8;
67236726
case AArch64::UABALB_ZZZ_D:
67246727
return AArch64::UABDLB_ZZZ_D;
67256728
case AArch64::UABALB_ZZZ_H:
@@ -7593,6 +7596,7 @@ std::optional<unsigned> AArch64InstrInfo::getReduceOpcodeForAccumulator(
75937596
return AArch64::ADDv2i32;
75947597
case AArch64::UABAv8i8:
75957598
case AArch64::SABAv8i8:
7599+
case AArch64::MLAv8i8:
75967600
return AArch64::ADDv8i8;
75977601
default:
75987602
llvm_unreachable("Unknown accumulator opcode");

llvm/test/CodeGen/AArch64/aarch64-reassociate-accumulators.ll

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,36 @@
1-
; RUN: opt -passes=loop-unroll %s -o - | llc -O3 - -mtriple=arm64e-apple-darwin -o - | FileCheck %s
1+
; RUN: opt -passes=loop-unroll %s -o - | llc -O3 - -mtriple=arm64e-apple-darwin -machine-combiner-recurse -o - | FileCheck %s
22

3+
define i8 @mla_i8_accumulation(ptr %ptr1, ptr %ptr2) {
4+
entry:
5+
br label %loop
6+
loop:
7+
%i = phi i32 [ 0, %entry ], [ %next_i, %loop ]
8+
%acc_phi = phi <8 x i8> [ zeroinitializer, %entry ], [ %acc_next, %loop ]
9+
%ptr1_i = getelementptr i8, ptr %ptr1, i32 %i
10+
%ptr2_i = getelementptr i8, ptr %ptr2, i32 %i
11+
%a = load <8 x i8>, <8 x i8>* %ptr1_i, align 1
12+
%b = load <8 x i8>, <8 x i8>* %ptr2_i, align 1
13+
%mul = mul <8 x i8> %a, %b
14+
%acc_next = add <8 x i8> %acc_phi, %mul
15+
%next_i = add i32 %i, 8
16+
%cmp = icmp slt i32 %next_i, 64
17+
br i1 %cmp, label %loop, label %exit
18+
exit:
19+
%reduce = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %acc_next)
20+
ret i8 %reduce
21+
}
22+
; CHECK-LABEL: mla_i8_accumulation
23+
; CHECK: mul.8b v1
24+
; CHECK: mul.8b v0
25+
; CHECK: mul.8b v2
26+
; CHECK: mla.8b v1
27+
; CHECK: mla.8b v0
28+
; CHECK: mla.8b v2
29+
; CHECK: mla.8b v1
30+
; CHECK: mla.8b v0
31+
; CHECK: add.8b v1, v2, v1
32+
; CHECK: add.8b v0, v1, v0
33+
; CHECK: addv.8b
334

435
define i16 @sabal_i8_to_i16_accumulation(ptr %ptr1, ptr %ptr2) {
536
entry:

0 commit comments

Comments
 (0)