Skip to content

Commit b702276

Browse files
committed
[AArch64] Add Machine InstCombiner patterns for FMUL indexed variant
This patch adds a DUP+FMUL => FMUL_indexed pattern to the Machine InstCombiner. FMUL_indexed is normally selected during instruction selection, but that does not work when the DUP and the FMUL are in different basic blocks. Differential Revision: https://reviews.llvm.org/D99662
1 parent 0076957 commit b702276

File tree

4 files changed

+825
-3
lines changed

4 files changed

+825
-3
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,18 @@ enum class MachineCombinerPattern {
153153
FMLSv4f32_OP1,
154154
FMLSv4f32_OP2,
155155
FMLSv4i32_indexed_OP1,
156-
FMLSv4i32_indexed_OP2
156+
FMLSv4i32_indexed_OP2,
157+
158+
FMULv2i32_indexed_OP1,
159+
FMULv2i32_indexed_OP2,
160+
FMULv2i64_indexed_OP1,
161+
FMULv2i64_indexed_OP2,
162+
FMULv4i16_indexed_OP1,
163+
FMULv4i16_indexed_OP2,
164+
FMULv4i32_indexed_OP1,
165+
FMULv4i32_indexed_OP2,
166+
FMULv8i16_indexed_OP1,
167+
FMULv8i16_indexed_OP2,
157168
};
158169

159170
} // end namespace llvm

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4917,6 +4917,55 @@ static bool getFMAPatterns(MachineInstr &Root,
49174917
return Found;
49184918
}
49194919

4920+
/// Look for machine-combiner opportunities that fold a vector FMUL whose
/// multiplicand is produced by a DUP-from-lane into an indexed FMUL.
/// Instruction selection normally forms FMUL_indexed directly, but it cannot
/// when the DUP and the FMUL live in different basic blocks; the machine
/// combiner covers that case.
///
/// \param Root the candidate FMUL instruction.
/// \param Patterns [out] matched patterns are appended here.
/// \returns true if at least one pattern was appended.
static bool getFMULPatterns(MachineInstr &Root,
                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
  MachineRegisterInfo &MRI = Root.getParent()->getParent()->getRegInfo();

  // Record \p Pattern if operand \p OpIdx of Root has a unique virtual-reg
  // definition (possibly in another block) with opcode \p DupOpcode.
  auto MatchDup = [&](unsigned DupOpcode, int OpIdx,
                      MachineCombinerPattern Pattern) -> bool {
    const MachineOperand &MO = Root.getOperand(OpIdx);
    if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
      return false;
    MachineInstr *Def = MRI.getUniqueVRegDef(MO.getReg());
    if (!Def || Def->getOpcode() != DupOpcode)
      return false;
    Patterns.push_back(Pattern);
    return true;
  };

  using MCP = MachineCombinerPattern;

  // Pick the DUP opcode and pattern pair for this FMUL variant; FMUL is
  // commutative, so both operand positions are tried below.
  unsigned DupOpc;
  MCP PatOp1, PatOp2;
  switch (Root.getOpcode()) {
  default:
    return false;
  case AArch64::FMULv2f32:
    DupOpc = AArch64::DUPv2i32lane;
    PatOp1 = MCP::FMULv2i32_indexed_OP1;
    PatOp2 = MCP::FMULv2i32_indexed_OP2;
    break;
  case AArch64::FMULv2f64:
    DupOpc = AArch64::DUPv2i64lane;
    PatOp1 = MCP::FMULv2i64_indexed_OP1;
    PatOp2 = MCP::FMULv2i64_indexed_OP2;
    break;
  case AArch64::FMULv4f16:
    DupOpc = AArch64::DUPv4i16lane;
    PatOp1 = MCP::FMULv4i16_indexed_OP1;
    PatOp2 = MCP::FMULv4i16_indexed_OP2;
    break;
  case AArch64::FMULv4f32:
    DupOpc = AArch64::DUPv4i32lane;
    PatOp1 = MCP::FMULv4i32_indexed_OP1;
    PatOp2 = MCP::FMULv4i32_indexed_OP2;
    break;
  case AArch64::FMULv8f16:
    DupOpc = AArch64::DUPv8i16lane;
    PatOp1 = MCP::FMULv8i16_indexed_OP1;
    PatOp2 = MCP::FMULv8i16_indexed_OP2;
    break;
  }

  // Check both multiplicands unconditionally (no short-circuit) so every
  // applicable pattern is recorded, OP1 first, then OP2.
  bool AnyMatched = MatchDup(DupOpc, 1, PatOp1);
  AnyMatched |= MatchDup(DupOpc, 2, PatOp2);
  return AnyMatched;
}
4968+
49204969
/// Return true when a code sequence can improve throughput. It
49214970
/// should be called only for instructions in loops.
49224971
/// \param Pattern - combiner pattern
@@ -4980,6 +5029,16 @@ bool AArch64InstrInfo::isThroughputPattern(
49805029
case MachineCombinerPattern::FMLSv2f64_OP2:
49815030
case MachineCombinerPattern::FMLSv4i32_indexed_OP2:
49825031
case MachineCombinerPattern::FMLSv4f32_OP2:
5032+
case MachineCombinerPattern::FMULv2i32_indexed_OP1:
5033+
case MachineCombinerPattern::FMULv2i32_indexed_OP2:
5034+
case MachineCombinerPattern::FMULv2i64_indexed_OP1:
5035+
case MachineCombinerPattern::FMULv2i64_indexed_OP2:
5036+
case MachineCombinerPattern::FMULv4i16_indexed_OP1:
5037+
case MachineCombinerPattern::FMULv4i16_indexed_OP2:
5038+
case MachineCombinerPattern::FMULv4i32_indexed_OP1:
5039+
case MachineCombinerPattern::FMULv4i32_indexed_OP2:
5040+
case MachineCombinerPattern::FMULv8i16_indexed_OP1:
5041+
case MachineCombinerPattern::FMULv8i16_indexed_OP2:
49835042
case MachineCombinerPattern::MULADDv8i8_OP1:
49845043
case MachineCombinerPattern::MULADDv8i8_OP2:
49855044
case MachineCombinerPattern::MULADDv16i8_OP1:
@@ -5036,6 +5095,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
50365095
if (getMaddPatterns(Root, Patterns))
50375096
return true;
50385097
// Floating point patterns
5098+
if (getFMULPatterns(Root, Patterns))
5099+
return true;
50395100
if (getFMAPatterns(Root, Patterns))
50405101
return true;
50415102

@@ -5124,6 +5185,42 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
51245185
return MUL;
51255186
}
51265187

5188+
/// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
///
/// \param Root the FMUL to be replaced; operand \p IdxDupOp must be defined
///        by a DUP-from-lane (matched earlier by getFMULPatterns).
/// \param InsInstrs [out] the newly built indexed FMUL is appended here.
/// \param IdxDupOp index (1 or 2) of the Root operand fed by the DUP.
/// \param MulOpc opcode of the indexed FMUL to build.
/// \param RC register class to constrain the DUP source register to.
/// \param MRI the function's register info.
/// \returns &Root (the instruction being replaced), for the caller's
///          bookkeeping; the replacement itself is communicated via InsInstrs.
static MachineInstr *
genIndexedMultiply(MachineInstr &Root,
                   SmallVectorImpl<MachineInstr *> &InsInstrs,
                   unsigned IdxDupOp, unsigned MulOpc,
                   const TargetRegisterClass *RC, MachineRegisterInfo &MRI) {
  assert(((IdxDupOp == 1) || (IdxDupOp == 2)) &&
         "Invalid index of FMUL operand");

  MachineFunction &MF = *Root.getMF();
  const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();

  // The DUP feeding Root; pattern matching guarantees a unique vreg def.
  MachineInstr *Dup =
      MF.getRegInfo().getUniqueVRegDef(Root.getOperand(IdxDupOp).getReg());

  // The DUP's source becomes the indexed operand of the new FMUL.  Its kill
  // flags may no longer hold once it gains this extra (possibly cross-block)
  // use, so drop them.
  Register DupSrcReg = Dup->getOperand(1).getReg();
  MRI.clearKillFlags(DupSrcReg);
  // NOTE(review): constrainRegClass can fail and return nullptr; the result
  // is ignored here — confirm the source class is always compatible with RC
  // for the matched DUP opcodes.
  MRI.constrainRegClass(DupSrcReg, RC);

  // Lane immediate carried over from the DUP.
  unsigned DupSrcLane = Dup->getOperand(2).getImm();

  // The other multiplicand becomes the first source of the indexed FMUL.
  unsigned IdxMulOp = IdxDupOp == 1 ? 2 : 1;
  MachineOperand &MulOp = Root.getOperand(IdxMulOp);

  // Reuse Root's destination so all existing users see the new definition.
  Register ResultReg = Root.getOperand(0).getReg();

  MachineInstrBuilder MIB;
  MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MulOpc), ResultReg)
            .add(MulOp)
            .addReg(DupSrcReg)
            .addImm(DupSrcLane);

  InsInstrs.push_back(MIB);
  return &Root;
}
5223+
51275224
/// genFusedMultiplyAcc - Helper to generate fused multiply accumulate
51285225
/// instructions.
51295226
///
@@ -6082,12 +6179,53 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
60826179
}
60836180
break;
60846181
}
6182+
case MachineCombinerPattern::FMULv2i32_indexed_OP1:
6183+
case MachineCombinerPattern::FMULv2i32_indexed_OP2: {
6184+
unsigned IdxDupOp =
6185+
(Pattern == MachineCombinerPattern::FMULv2i32_indexed_OP1) ? 1 : 2;
6186+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv2i32_indexed,
6187+
&AArch64::FPR128RegClass, MRI);
6188+
break;
6189+
}
6190+
case MachineCombinerPattern::FMULv2i64_indexed_OP1:
6191+
case MachineCombinerPattern::FMULv2i64_indexed_OP2: {
6192+
unsigned IdxDupOp =
6193+
(Pattern == MachineCombinerPattern::FMULv2i64_indexed_OP1) ? 1 : 2;
6194+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv2i64_indexed,
6195+
&AArch64::FPR128RegClass, MRI);
6196+
break;
6197+
}
6198+
case MachineCombinerPattern::FMULv4i16_indexed_OP1:
6199+
case MachineCombinerPattern::FMULv4i16_indexed_OP2: {
6200+
unsigned IdxDupOp =
6201+
(Pattern == MachineCombinerPattern::FMULv4i16_indexed_OP1) ? 1 : 2;
6202+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv4i16_indexed,
6203+
&AArch64::FPR128_loRegClass, MRI);
6204+
break;
6205+
}
6206+
case MachineCombinerPattern::FMULv4i32_indexed_OP1:
6207+
case MachineCombinerPattern::FMULv4i32_indexed_OP2: {
6208+
unsigned IdxDupOp =
6209+
(Pattern == MachineCombinerPattern::FMULv4i32_indexed_OP1) ? 1 : 2;
6210+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv4i32_indexed,
6211+
&AArch64::FPR128RegClass, MRI);
6212+
break;
6213+
}
6214+
case MachineCombinerPattern::FMULv8i16_indexed_OP1:
6215+
case MachineCombinerPattern::FMULv8i16_indexed_OP2: {
6216+
unsigned IdxDupOp =
6217+
(Pattern == MachineCombinerPattern::FMULv8i16_indexed_OP1) ? 1 : 2;
6218+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv8i16_indexed,
6219+
&AArch64::FPR128_loRegClass, MRI);
6220+
break;
6221+
}
60856222
} // end switch (Pattern)
60866223
// Record MUL and ADD/SUB for deletion
60876224
// FIXME: This assertion fails in CodeGen/AArch64/tailmerging_in_mbp.ll and
60886225
// CodeGen/AArch64/urem-seteq-nonzero.ll.
60896226
// assert(MUL && "MUL was never set");
6090-
DelInstrs.push_back(MUL);
6227+
if (MUL)
6228+
DelInstrs.push_back(MUL);
60916229
DelInstrs.push_back(&Root);
60926230
}
60936231

llvm/test/CodeGen/AArch64/arm64-fma-combines.ll

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
; RUN: llc < %s -O=3 -mtriple=arm64-apple-ios -mcpu=cyclone -enable-unsafe-fp-math | FileCheck %s
1+
; RUN: llc < %s -O=3 -mtriple=arm64-apple-ios -mcpu=cyclone -mattr=+fullfp16 -enable-unsafe-fp-math -verify-machineinstrs | FileCheck %s
2+
23
define void @foo_2d(double* %src) {
34
; CHECK-LABEL: %entry
45
; CHECK: fmul {{d[0-9]+}}, {{d[0-9]+}}, {{d[0-9]+}}
@@ -134,3 +135,128 @@ for.body: ; preds = %for.body, %entry
134135
for.end: ; preds = %for.body
135136
ret void
136137
}
138+
139+
; The splat (DUP) is in %entry while the fmul is in %for.body: ISel alone
; cannot form the indexed multiply across blocks, so the machine combiner
; must fold it; the CHECK expects it fused with the fadd into an indexed fmla.
define void @indexed_2s(<2 x float> %shuf, <2 x float> %add,
                        <2 x float>* %pmul, <2 x float>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.2s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <2 x float> %shuf, <2 x float> undef, <2 x i32> zeroinitializer
  br label %for.body

for.body:
  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
  %pmul_i = getelementptr inbounds <2 x float>, <2 x float>* %pmul, i64 %i
  %pret_i = getelementptr inbounds <2 x float>, <2 x float>* %pret, i64 %i

  %mul_i = load <2 x float>, <2 x float>* %pmul_i

  %mul = fmul fast <2 x float> %mul_i, %shuffle
  %muladd = fadd fast <2 x float> %mul, %add

  store <2 x float> %muladd, <2 x float>* %pret_i, align 16
  %inext = add i64 %i, 1
  br label %for.body
}
163+
164+
; Same cross-block DUP+FMUL shape as @indexed_2s, for <2 x double>;
; expects an indexed fmla.2d after combining.
define void @indexed_2d(<2 x double> %shuf, <2 x double> %add,
                        <2 x double>* %pmul, <2 x double>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.2d {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <2 x double> %shuf, <2 x double> undef, <2 x i32> zeroinitializer
  br label %for.body

for.body:
  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
  %pmul_i = getelementptr inbounds <2 x double>, <2 x double>* %pmul, i64 %i
  %pret_i = getelementptr inbounds <2 x double>, <2 x double>* %pret, i64 %i

  %mul_i = load <2 x double>, <2 x double>* %pmul_i

  %mul = fmul fast <2 x double> %mul_i, %shuffle
  %muladd = fadd fast <2 x double> %mul, %add

  store <2 x double> %muladd, <2 x double>* %pret_i, align 16
  %inext = add i64 %i, 1
  br label %for.body
}
188+
189+
; Same cross-block DUP+FMUL shape as @indexed_2s, for <4 x float>;
; expects an indexed fmla.4s after combining.
define void @indexed_4s(<4 x float> %shuf, <4 x float> %add,
                        <4 x float>* %pmul, <4 x float>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.4s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <4 x float> %shuf, <4 x float> undef, <4 x i32> zeroinitializer
  br label %for.body

for.body:
  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
  %pmul_i = getelementptr inbounds <4 x float>, <4 x float>* %pmul, i64 %i
  %pret_i = getelementptr inbounds <4 x float>, <4 x float>* %pret, i64 %i

  %mul_i = load <4 x float>, <4 x float>* %pmul_i

  %mul = fmul fast <4 x float> %mul_i, %shuffle
  %muladd = fadd fast <4 x float> %mul, %add

  store <4 x float> %muladd, <4 x float>* %pret_i, align 16
  %inext = add i64 %i, 1
  br label %for.body
}
213+
214+
; Same cross-block DUP+FMUL shape as @indexed_2s, for <4 x half>
; (requires +fullfp16 from the RUN line); expects an indexed fmla.4h.
define void @indexed_4h(<4 x half> %shuf, <4 x half> %add,
                        <4 x half>* %pmul, <4 x half>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.4h {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <4 x half> %shuf, <4 x half> undef, <4 x i32> zeroinitializer
  br label %for.body

for.body:
  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
  %pmul_i = getelementptr inbounds <4 x half>, <4 x half>* %pmul, i64 %i
  %pret_i = getelementptr inbounds <4 x half>, <4 x half>* %pret, i64 %i

  %mul_i = load <4 x half>, <4 x half>* %pmul_i

  %mul = fmul fast <4 x half> %mul_i, %shuffle
  %muladd = fadd fast <4 x half> %mul, %add

  store <4 x half> %muladd, <4 x half>* %pret_i, align 16
  %inext = add i64 %i, 1
  br label %for.body
}
238+
239+
; Same cross-block DUP+FMUL shape as @indexed_2s, for <8 x half>
; (requires +fullfp16 from the RUN line); expects an indexed fmla.8h.
define void @indexed_8h(<8 x half> %shuf, <8 x half> %add,
                        <8 x half>* %pmul, <8 x half>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.8h {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <8 x half> %shuf, <8 x half> undef, <8 x i32> zeroinitializer
  br label %for.body

for.body:
  %i = phi i64 [ 0, %entry ], [ %inext, %for.body ]
  %pmul_i = getelementptr inbounds <8 x half>, <8 x half>* %pmul, i64 %i
  %pret_i = getelementptr inbounds <8 x half>, <8 x half>* %pret, i64 %i

  %mul_i = load <8 x half>, <8 x half>* %pmul_i

  %mul = fmul fast <8 x half> %mul_i, %shuffle
  %muladd = fadd fast <8 x half> %mul, %add

  store <8 x half> %muladd, <8 x half>* %pret_i, align 16
  %inext = add i64 %i, 1
  br label %for.body
}

0 commit comments

Comments
 (0)