Commit caa95c2

[AArch64] Emit FNMADD instead of FNEG(FMADD)
Emit FNMADD instead of FNEG(FMADD) for optimization levels above Oz when the fast-math flags (nsz + contract) permit it.

Differential Revision: https://reviews.llvm.org/D149260

1 parent 00ff746 commit caa95c2
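
As an illustration (not part of the commit), here is a minimal C++ sketch of the source pattern this targets. The function names are hypothetical, and the flags shown are just one way to get contract and nsz onto the floating-point operations (e.g. clang -O3 -ffp-contract=fast -fno-signed-zeros):

// Hypothetical example, not from the commit: with contract and nsz present,
// the negated fused multiply-add should be selectable as a single fnmadd
// instead of fmadd followed by fneg (compare the fnmaddd test below).
double neg_fma_f64(double a, double b, double c) {
  return -(a * b + c); // expected: fnmadd d0, d0, d1, d2
}

float neg_fma_f32(float a, float b, float c) {
  return -(a * b + c); // expected: fnmadd s0, s0, s1, s2
}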

3 files changed, +220 -0 lines changed

llvm/include/llvm/CodeGen/MachineCombinerPattern.h

Lines changed: 3 additions & 0 deletions

@@ -178,6 +178,9 @@ enum class MachineCombinerPattern {
 
   // X86 VNNI
   DPWSSD,
+
+  FNMADDS,
+  FNMADDD,
 };
 
 } // end namespace llvm

llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

Lines changed: 87 additions & 0 deletions

@@ -5409,6 +5409,41 @@ static bool getFMULPatterns(MachineInstr &Root,
   return Found;
 }
 
+static bool getFNEGPatterns(MachineInstr &Root,
+                            SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+  unsigned Opc = Root.getOpcode();
+  MachineBasicBlock &MBB = *Root.getParent();
+  MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  bool Found = false;
+
+  auto Match = [&](unsigned Opcode, MachineCombinerPattern Pattern) -> bool {
+    MachineOperand &MO = Root.getOperand(1);
+    MachineInstr *MI = MRI.getUniqueVRegDef(MO.getReg());
+    if ((MI->getOpcode() == Opcode) &&
+        Root.getFlag(MachineInstr::MIFlag::FmContract) &&
+        Root.getFlag(MachineInstr::MIFlag::FmNsz) &&
+        MI->getFlag(MachineInstr::MIFlag::FmContract) &&
+        MI->getFlag(MachineInstr::MIFlag::FmNsz)) {
+      Patterns.push_back(Pattern);
+      return true;
+    }
+    return false;
+  };
+
+  switch (Opc) {
+  default:
+    return false;
+  case AArch64::FNEGDr:
+    Found |= Match(AArch64::FMADDDrrr, MachineCombinerPattern::FNMADDD);
+    break;
+  case AArch64::FNEGSr:
+    Found |= Match(AArch64::FMADDSrrr, MachineCombinerPattern::FNMADDS);
+    break;
+  }
+
+  return Found;
+}
+
 /// Return true when a code sequence can improve throughput. It
 /// should be called only for instructions in loops.
 /// \param Pattern - combiner pattern
@@ -5578,6 +5613,8 @@ bool AArch64InstrInfo::getMachineCombinerPatterns(
     return true;
   if (getFMAPatterns(Root, Patterns))
     return true;
+  if (getFNEGPatterns(Root, Patterns))
+    return true;
 
   // Other patterns
   if (getMiscPatterns(Root, Patterns))
@@ -5668,6 +5705,39 @@ genFusedMultiply(MachineFunction &MF, MachineRegisterInfo &MRI,
   return MUL;
 }
 
+static MachineInstr *
+genFNegatedMAD(MachineFunction &MF, MachineRegisterInfo &MRI,
+               const TargetInstrInfo *TII, MachineInstr &Root,
+               SmallVectorImpl<MachineInstr *> &InsInstrs, unsigned Opc,
+               const TargetRegisterClass *RC) {
+  MachineInstr *MAD = MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
+  Register ResultReg = Root.getOperand(0).getReg();
+  Register SrcReg0 = MAD->getOperand(1).getReg();
+  Register SrcReg1 = MAD->getOperand(2).getReg();
+  Register SrcReg2 = MAD->getOperand(3).getReg();
+  bool Src0IsKill = MAD->getOperand(1).isKill();
+  bool Src1IsKill = MAD->getOperand(2).isKill();
+  bool Src2IsKill = MAD->getOperand(3).isKill();
+
+  if (ResultReg.isVirtual())
+    MRI.constrainRegClass(ResultReg, RC);
+  if (SrcReg0.isVirtual())
+    MRI.constrainRegClass(SrcReg0, RC);
+  if (SrcReg1.isVirtual())
+    MRI.constrainRegClass(SrcReg1, RC);
+  if (SrcReg2.isVirtual())
+    MRI.constrainRegClass(SrcReg2, RC);
+
+  MachineInstrBuilder MIB =
+      BuildMI(MF, MIMetadata(Root), TII->get(Opc), ResultReg)
+          .addReg(SrcReg0, getKillRegState(Src0IsKill))
+          .addReg(SrcReg1, getKillRegState(Src1IsKill))
+          .addReg(SrcReg2, getKillRegState(Src2IsKill));
+  InsInstrs.push_back(MIB);
+
+  return MAD;
+}
+
 /// Fold (FMUL x (DUP y lane)) into (FMUL_indexed x y lane)
 static MachineInstr *
 genIndexedMultiply(MachineInstr &Root,
@@ -5894,6 +5964,7 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
 
   MachineInstr *MUL = nullptr;
+  MachineInstr *MAD = nullptr;
   const TargetRegisterClass *RC;
   unsigned Opc;
   switch (Pattern) {
@@ -6800,6 +6871,20 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
                        &AArch64::FPR128_loRegClass, MRI);
     break;
   }
+
+  case MachineCombinerPattern::FNMADDS: {
+    Opc = AArch64::FNMADDSrrr;
+    RC = &AArch64::FPR32RegClass;
+    MAD = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs, Opc, RC);
+    break;
+  }
+  case MachineCombinerPattern::FNMADDD: {
+    Opc = AArch64::FNMADDDrrr;
+    RC = &AArch64::FPR64RegClass;
+    MAD = genFNegatedMAD(MF, MRI, TII, Root, InsInstrs, Opc, RC);
+    break;
+  }
+
   } // end switch (Pattern)
   // Record MUL and ADD/SUB for deletion
   if (MUL)
@@ -6811,6 +6896,8 @@ void AArch64InstrInfo::genAlternativeCodeSequence(
   uint16_t Flags = Root.getFlags();
   if (MUL)
     Flags = Root.mergeFlagsWith(*MUL);
+  if (MAD)
+    Flags = Root.mergeFlagsWith(*MAD);
   for (auto *MI : InsInstrs)
     MI->setFlags(Flags);
 }

Lines changed: 130 additions & 0 deletions (new test file)

@@ -0,0 +1,130 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
+; RUN: llc < %s -mtriple=aarch64-linux-gnu -O3 | FileCheck %s
+
+define void @fnmaddd(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmaddd:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x1]
+; CHECK-NEXT:    ldr d1, [x0]
+; CHECK-NEXT:    ldr d2, [x2]
+; CHECK-NEXT:    fnmadd d0, d0, d1, d2
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul fast double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd fast double %mul, %2
+  %fneg = fneg fast double %add
+  store double %fneg, ptr %a, align 8
+  ret void
+}
+
+; Don't combine: No flags
+define void @fnmaddd_no_fast(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmaddd_no_fast:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr d0, [x0]
+; CHECK-NEXT:    ldr d1, [x1]
+; CHECK-NEXT:    fmul d0, d1, d0
+; CHECK-NEXT:    ldr d1, [x2]
+; CHECK-NEXT:    fadd d0, d0, d1
+; CHECK-NEXT:    fneg d0, d0
+; CHECK-NEXT:    str d0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load double, ptr %a, align 8
+  %1 = load double, ptr %b, align 8
+  %mul = fmul double %1, %0
+  %2 = load double, ptr %c, align 8
+  %add = fadd double %mul, %2
+  %fneg = fneg double %add
+  store double %fneg, ptr %a, align 8
+  ret void
+}
+
+define void @fnmadds(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fnmadd s0, s0, s1, s2
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul fast float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd fast float %mul, %2
+  %fneg = fneg fast float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+define void @fnmadds_nsz_contract(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_nsz_contract:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fnmadd s0, s0, s1, s2
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul contract nsz float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd contract nsz float %mul, %2
+  %fneg = fneg contract nsz float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+; Don't combine: Missing nsz
+define void @fnmadds_contract(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_contract:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x1]
+; CHECK-NEXT:    ldr s1, [x0]
+; CHECK-NEXT:    ldr s2, [x2]
+; CHECK-NEXT:    fmadd s0, s0, s1, s2
+; CHECK-NEXT:    fneg s0, s0
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul contract float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd contract float %mul, %2
+  %fneg = fneg contract float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
+
+; Don't combine: Missing contract
+define void @fnmadds_nsz(ptr %a, ptr %b, ptr %c) {
+; CHECK-LABEL: fnmadds_nsz:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ldr s0, [x0]
+; CHECK-NEXT:    ldr s1, [x1]
+; CHECK-NEXT:    fmul s0, s1, s0
+; CHECK-NEXT:    ldr s1, [x2]
+; CHECK-NEXT:    fadd s0, s0, s1
+; CHECK-NEXT:    fneg s0, s0
+; CHECK-NEXT:    str s0, [x0]
+; CHECK-NEXT:    ret
+entry:
+  %0 = load float, ptr %a, align 4
+  %1 = load float, ptr %b, align 4
+  %mul = fmul nsz float %1, %0
+  %2 = load float, ptr %c, align 4
+  %add = fadd nsz float %mul, %2
+  %fneg = fneg nsz float %add
+  store float %fneg, ptr %a, align 4
+  ret void
+}
