Skip to content

Commit 1378fdc

Browse files
asavonic authored and fhahn committed
[AArch64] Add Machine InstCombiner patterns for FMUL indexed variant
This patch adds DUP+FMUL => FMUL_indexed pattern to InstCombiner. FMUL_indexed is normally selected during instruction selection, but it does not work in cases when VDUP and VMUL are in different basic blocks. Differential Revision: https://reviews.llvm.org/D99662 (cherry-picked from b702276)
1 parent e47be78 commit 1378fdc

File tree

4 files changed

+825
-3
lines changed

4 files changed

+825
-3
lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,18 @@ enum class MachineCombinerPattern {
153153
FMLSv4f32_OP1,
154154
FMLSv4f32_OP2,
155155
FMLSv4i32_indexed_OP1,
156-
FMLSv4i32_indexed_OP2
156+
FMLSv4i32_indexed_OP2,
157+
158+
FMULv2i32_indexed_OP1,
159+
FMULv2i32_indexed_OP2,
160+
FMULv2i64_indexed_OP1,
161+
FMULv2i64_indexed_OP2,
162+
FMULv4i16_indexed_OP1,
163+
FMULv4i16_indexed_OP2,
164+
FMULv4i32_indexed_OP1,
165+
FMULv4i32_indexed_OP2,
166+
FMULv8i16_indexed_OP1,
167+
FMULv8i16_indexed_OP2,
157168
};
158169

159170
} // end namespace llvm

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4880,6 +4880,55 @@ static bool getFMAPatterns(MachineInstr &Root,
48804880
return Found;
48814881
}
48824882

4883+
/// Look for DUP + FMUL sequences that can be rewritten as an indexed
/// FMUL (FMUL_indexed). Instruction selection already forms the indexed
/// variant when DUP and FMUL sit in the same block; this catches the
/// cross-basic-block case for the machine combiner.
///
/// \param Root the candidate FMUL instruction.
/// \param Patterns output list; matching combiner patterns are appended.
/// \returns true if at least one pattern was recorded.
static bool getFMULPatterns(MachineInstr &Root,
                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
  MachineBasicBlock &MBB = *Root.getParent();
  bool Matched = false;

  // Record \p Pattern if operand \p OpIdx of Root is uniquely defined by a
  // DUP-lane instruction with opcode \p DupOpcode.
  auto TryMatch = [&](unsigned DupOpcode, int OpIdx,
                      MachineCombinerPattern Pattern) -> bool {
    MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
    MachineOperand &MO = Root.getOperand(OpIdx);
    MachineInstr *Def = nullptr;
    // Only virtual registers have a meaningful unique def to inspect.
    if (MO.isReg() && Register::isVirtualRegister(MO.getReg()))
      Def = MRI.getUniqueVRegDef(MO.getReg());
    if (!Def || Def->getOpcode() != DupOpcode)
      return false;
    Patterns.push_back(Pattern);
    return true;
  };

  typedef MachineCombinerPattern MCP;

  // Try both FMUL operands: either one may be the DUP result.
  switch (Root.getOpcode()) {
  default:
    return false;
  case AArch64::FMULv2f32:
    Matched = TryMatch(AArch64::DUPv2i32lane, 1, MCP::FMULv2i32_indexed_OP1);
    Matched |= TryMatch(AArch64::DUPv2i32lane, 2, MCP::FMULv2i32_indexed_OP2);
    break;
  case AArch64::FMULv2f64:
    Matched = TryMatch(AArch64::DUPv2i64lane, 1, MCP::FMULv2i64_indexed_OP1);
    Matched |= TryMatch(AArch64::DUPv2i64lane, 2, MCP::FMULv2i64_indexed_OP2);
    break;
  case AArch64::FMULv4f16:
    Matched = TryMatch(AArch64::DUPv4i16lane, 1, MCP::FMULv4i16_indexed_OP1);
    Matched |= TryMatch(AArch64::DUPv4i16lane, 2, MCP::FMULv4i16_indexed_OP2);
    break;
  case AArch64::FMULv4f32:
    Matched = TryMatch(AArch64::DUPv4i32lane, 1, MCP::FMULv4i32_indexed_OP1);
    Matched |= TryMatch(AArch64::DUPv4i32lane, 2, MCP::FMULv4i32_indexed_OP2);
    break;
  case AArch64::FMULv8f16:
    Matched = TryMatch(AArch64::DUPv8i16lane, 1, MCP::FMULv8i16_indexed_OP1);
    Matched |= TryMatch(AArch64::DUPv8i16lane, 2, MCP::FMULv8i16_indexed_OP2);
    break;
  }

  return Matched;
}
4931+
48834932
/// Return true when a code sequence can improve throughput. It
48844933
/// should be called only for instructions in loops.
48854934
/// \param Pattern - combiner pattern
@@ -4943,6 +4992,16 @@ bool AArch64InstrInfo::isThroughputPattern(
49434992
case MachineCombinerPattern::FMLSv2f64_OP2:
49444993
case MachineCombinerPattern::FMLSv4i32_indexed_OP2:
49454994
case MachineCombinerPattern::FMLSv4f32_OP2:
4995+
case MachineCombinerPattern::FMULv2i32_indexed_OP1:
4996+
case MachineCombinerPattern::FMULv2i32_indexed_OP2:
4997+
case MachineCombinerPattern::FMULv2i64_indexed_OP1:
4998+
case MachineCombinerPattern::FMULv2i64_indexed_OP2:
4999+
case MachineCombinerPattern::FMULv4i16_indexed_OP1:
5000+
case MachineCombinerPattern::FMULv4i16_indexed_OP2:
5001+
case MachineCombinerPattern::FMULv4i32_indexed_OP1:
5002+
case MachineCombinerPattern::FMULv4i32_indexed_OP2:
5003+
case MachineCombinerPattern::FMULv8i16_indexed_OP1:
5004+
case MachineCombinerPattern::FMULv8i16_indexed_OP2:
49465005
case MachineCombinerPattern::MULADDv8i8_OP1:
49475006
case MachineCombinerPattern::MULADDv8i8_OP2:
49485007
case MachineCombinerPattern::MULADDv16i8_OP1:
@@ -4999,6 +5058,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
49995058
if (getMaddPatterns(Root, Patterns))
50005059
return true;
50015060
// Floating point patterns
5061+
if (getFMULPatterns(Root, Patterns))
5062+
return true;
50025063
if (getFMAPatterns(Root, Patterns))
50035064
return true;
50045065

@@ -5087,6 +5148,42 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
50875148
return MUL;
50885149
}
50895150

5151+
/// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
5152+
static MachineInstr *
5153+
genIndexedMultiply(MachineInstr &Root,
5154+
SmallVectorImpl<MachineInstr *> &InsInstrs,
5155+
unsigned IdxDupOp, unsigned MulOpc,
5156+
const TargetRegisterClass *RC, MachineRegisterInfo &MRI) {
5157+
assert(((IdxDupOp == 1) || (IdxDupOp == 2)) &&
5158+
"Invalid index of FMUL operand");
5159+
5160+
MachineFunction &MF = *Root.getMF();
5161+
const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
5162+
5163+
MachineInstr *Dup =
5164+
MF.getRegInfo().getUniqueVRegDef(Root.getOperand(IdxDupOp).getReg());
5165+
5166+
Register DupSrcReg = Dup->getOperand(1).getReg();
5167+
MRI.clearKillFlags(DupSrcReg);
5168+
MRI.constrainRegClass(DupSrcReg, RC);
5169+
5170+
unsigned DupSrcLane = Dup->getOperand(2).getImm();
5171+
5172+
unsigned IdxMulOp = IdxDupOp == 1 ? 2 : 1;
5173+
MachineOperand &MulOp = Root.getOperand(IdxMulOp);
5174+
5175+
Register ResultReg = Root.getOperand(0).getReg();
5176+
5177+
MachineInstrBuilder MIB;
5178+
MIB = BuildMI(MF, Root.getDebugLoc(), TII->get(MulOpc), ResultReg)
5179+
.add(MulOp)
5180+
.addReg(DupSrcReg)
5181+
.addImm(DupSrcLane);
5182+
5183+
InsInstrs.push_back(MIB);
5184+
return &Root;
5185+
}
5186+
50905187
/// genFusedMultiplyAcc - Helper to generate fused multiply accumulate
50915188
/// instructions.
50925189
///
@@ -6045,12 +6142,53 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
60456142
}
60466143
break;
60476144
}
6145+
case MachineCombinerPattern::FMULv2i32_indexed_OP1:
6146+
case MachineCombinerPattern::FMULv2i32_indexed_OP2: {
6147+
unsigned IdxDupOp =
6148+
(Pattern == MachineCombinerPattern::FMULv2i32_indexed_OP1) ? 1 : 2;
6149+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv2i32_indexed,
6150+
&AArch64::FPR128RegClass, MRI);
6151+
break;
6152+
}
6153+
case MachineCombinerPattern::FMULv2i64_indexed_OP1:
6154+
case MachineCombinerPattern::FMULv2i64_indexed_OP2: {
6155+
unsigned IdxDupOp =
6156+
(Pattern == MachineCombinerPattern::FMULv2i64_indexed_OP1) ? 1 : 2;
6157+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv2i64_indexed,
6158+
&AArch64::FPR128RegClass, MRI);
6159+
break;
6160+
}
6161+
case MachineCombinerPattern::FMULv4i16_indexed_OP1:
6162+
case MachineCombinerPattern::FMULv4i16_indexed_OP2: {
6163+
unsigned IdxDupOp =
6164+
(Pattern == MachineCombinerPattern::FMULv4i16_indexed_OP1) ? 1 : 2;
6165+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv4i16_indexed,
6166+
&AArch64::FPR128_loRegClass, MRI);
6167+
break;
6168+
}
6169+
case MachineCombinerPattern::FMULv4i32_indexed_OP1:
6170+
case MachineCombinerPattern::FMULv4i32_indexed_OP2: {
6171+
unsigned IdxDupOp =
6172+
(Pattern == MachineCombinerPattern::FMULv4i32_indexed_OP1) ? 1 : 2;
6173+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv4i32_indexed,
6174+
&AArch64::FPR128RegClass, MRI);
6175+
break;
6176+
}
6177+
case MachineCombinerPattern::FMULv8i16_indexed_OP1:
6178+
case MachineCombinerPattern::FMULv8i16_indexed_OP2: {
6179+
unsigned IdxDupOp =
6180+
(Pattern == MachineCombinerPattern::FMULv8i16_indexed_OP1) ? 1 : 2;
6181+
genIndexedMultiply(Root, InsInstrs, IdxDupOp, AArch64::FMULv8i16_indexed,
6182+
&AArch64::FPR128_loRegClass, MRI);
6183+
break;
6184+
}
60486185
} // end switch (Pattern)
60496186
// Record MUL and ADD/SUB for deletion
60506187
// FIXME: This assertion fails in CodeGen/AArch64/tailmerging_in_mbp.ll and
60516188
// CodeGen/AArch64/urem-seteq-nonzero.ll.
60526189
// assert(MUL && "MUL was never set");
6053-
DelInstrs.push_back(MUL);
6190+
if (MUL)
6191+
DelInstrs.push_back(MUL);
60546192
DelInstrs.push_back(&Root);
60556193
}
60566194

llvm/test/CodeGen/AArch64/arm64-fma-combines.ll

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
; RUN: llc < %s -O=3 -mtriple=arm64-apple-ios -mcpu=cyclone -enable-unsafe-fp-math | FileCheck %s
1+
; RUN: llc < %s -O=3 -mtriple=arm64-apple-ios -mcpu=cyclone -mattr=+fullfp16 -enable-unsafe-fp-math -verify-machineinstrs | FileCheck %s
2+
23
define void @foo_2d(double* %src) {
34
; CHECK-LABEL: %entry
45
; CHECK: fmul {{d[0-9]+}}, {{d[0-9]+}}, {{d[0-9]+}}
@@ -134,3 +135,128 @@ for.body: ; preds = %for.body, %entry
134135
for.end: ; preds = %for.body
135136
ret void
136137
}
138+
139+
; DUP (from the shufflevector splat) is in %entry while the FMUL is in
; %for.body; the machine combiner must still form the indexed multiply,
; which then fuses with the fadd into an indexed fmla.
define void @indexed_2s(<2 x float> %shuf, <2 x float> %add,
                        <2 x float>* %pmul, <2 x float>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.2s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <2 x float> %shuf, <2 x float> undef, <2 x i32> zeroinitializer
  br label %for.body

for.body:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
  %pmul_i = getelementptr inbounds <2 x float>, <2 x float>* %pmul, i64 %iv
  %pret_i = getelementptr inbounds <2 x float>, <2 x float>* %pret, i64 %iv

  %mul_i = load <2 x float>, <2 x float>* %pmul_i

  %mul = fmul fast <2 x float> %mul_i, %shuffle
  %muladd = fadd fast <2 x float> %mul, %add

  store <2 x float> %muladd, <2 x float>* %pret_i, align 16
  %iv.next = add i64 %iv, 1
  br label %for.body
}
163+
164+
; Same as @indexed_2s but for <2 x double>: cross-block DUP + FMUL must
; become an indexed multiply and fuse into fmla.2d.
define void @indexed_2d(<2 x double> %shuf, <2 x double> %add,
                        <2 x double>* %pmul, <2 x double>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.2d {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <2 x double> %shuf, <2 x double> undef, <2 x i32> zeroinitializer
  br label %for.body

for.body:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
  %pmul_i = getelementptr inbounds <2 x double>, <2 x double>* %pmul, i64 %iv
  %pret_i = getelementptr inbounds <2 x double>, <2 x double>* %pret, i64 %iv

  %mul_i = load <2 x double>, <2 x double>* %pmul_i

  %mul = fmul fast <2 x double> %mul_i, %shuffle
  %muladd = fadd fast <2 x double> %mul, %add

  store <2 x double> %muladd, <2 x double>* %pret_i, align 16
  %iv.next = add i64 %iv, 1
  br label %for.body
}
188+
189+
; <4 x float> variant: cross-block DUP + FMUL must become an indexed
; multiply and fuse into fmla.4s.
define void @indexed_4s(<4 x float> %shuf, <4 x float> %add,
                        <4 x float>* %pmul, <4 x float>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.4s {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <4 x float> %shuf, <4 x float> undef, <4 x i32> zeroinitializer
  br label %for.body

for.body:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
  %pmul_i = getelementptr inbounds <4 x float>, <4 x float>* %pmul, i64 %iv
  %pret_i = getelementptr inbounds <4 x float>, <4 x float>* %pret, i64 %iv

  %mul_i = load <4 x float>, <4 x float>* %pmul_i

  %mul = fmul fast <4 x float> %mul_i, %shuffle
  %muladd = fadd fast <4 x float> %mul, %add

  store <4 x float> %muladd, <4 x float>* %pret_i, align 16
  %iv.next = add i64 %iv, 1
  br label %for.body
}
213+
214+
; <4 x half> variant (requires +fullfp16 from the RUN line): cross-block
; DUP + FMUL must become an indexed multiply and fuse into fmla.4h.
define void @indexed_4h(<4 x half> %shuf, <4 x half> %add,
                        <4 x half>* %pmul, <4 x half>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.4h {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <4 x half> %shuf, <4 x half> undef, <4 x i32> zeroinitializer
  br label %for.body

for.body:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
  %pmul_i = getelementptr inbounds <4 x half>, <4 x half>* %pmul, i64 %iv
  %pret_i = getelementptr inbounds <4 x half>, <4 x half>* %pret, i64 %iv

  %mul_i = load <4 x half>, <4 x half>* %pmul_i

  %mul = fmul fast <4 x half> %mul_i, %shuffle
  %muladd = fadd fast <4 x half> %mul, %add

  store <4 x half> %muladd, <4 x half>* %pret_i, align 16
  %iv.next = add i64 %iv, 1
  br label %for.body
}
238+
239+
; <8 x half> variant (requires +fullfp16 from the RUN line): cross-block
; DUP + FMUL must become an indexed multiply and fuse into fmla.8h.
define void @indexed_8h(<8 x half> %shuf, <8 x half> %add,
                        <8 x half>* %pmul, <8 x half>* %pret) {
; CHECK-LABEL: %entry
; CHECK: for.body
; CHECK: fmla.8h {{v[0-9]+}}, {{v[0-9]+}}, {{v[0-9]+}}[0]
;
entry:
  %shuffle = shufflevector <8 x half> %shuf, <8 x half> undef, <8 x i32> zeroinitializer
  br label %for.body

for.body:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
  %pmul_i = getelementptr inbounds <8 x half>, <8 x half>* %pmul, i64 %iv
  %pret_i = getelementptr inbounds <8 x half>, <8 x half>* %pret, i64 %iv

  %mul_i = load <8 x half>, <8 x half>* %pmul_i

  %mul = fmul fast <8 x half> %mul_i, %shuffle
  %muladd = fadd fast <8 x half> %mul, %add

  store <8 x half> %muladd, <8 x half>* %pret_i, align 16
  %iv.next = add i64 %iv, 1
  br label %for.body
}

0 commit comments

Comments (0)