Skip to content

Commit 8bd327d

Browse files
authored
[AMDGPU][GlobalISel] Add fdiv / sqrt to rsq combine (#78673)
Fixes #64743
1 parent e899641 commit 8bd327d

File tree

3 files changed

+614
-1
lines changed

3 files changed

+614
-1
lines changed

llvm/lib/Target/AMDGPU/AMDGPUCombine.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ def rcp_sqrt_to_rsq : GICombineRule<
3333
[{ return matchRcpSqrtToRsq(*${rcp}, ${matchinfo}); }]),
3434
(apply [{ Helper.applyBuildFn(*${rcp}, ${matchinfo}); }])>;
3535

36+
def fdiv_by_sqrt_to_rsq_f16 : GICombineRule<
37+
(defs root:$root),
38+
(match (G_FSQRT f16:$sqrt, $x, (MIFlags FmContract)),
39+
(G_FDIV f16:$dst, $y, $sqrt, (MIFlags FmContract)):$root,
40+
[{ return matchFDivSqrtToRsqF16(*${root}); }]),
41+
(apply [{ applyFDivSqrtToRsqF16(*${root}, ${x}.getReg()); }])>;
3642

3743
def cvt_f32_ubyteN_matchdata : GIDefMatchData<"CvtF32UByteMatchInfo">;
3844

@@ -156,7 +162,7 @@ def AMDGPUPostLegalizerCombiner: GICombiner<
156162
"AMDGPUPostLegalizerCombinerImpl",
157163
[all_combines, gfx6gfx7_combines, gfx8_combines,
158164
uchar_to_float, cvt_f32_ubyteN, remove_fcanonicalize, foldable_fneg,
159-
rcp_sqrt_to_rsq, sign_extension_in_reg, smulu64]> {
165+
rcp_sqrt_to_rsq, fdiv_by_sqrt_to_rsq_f16, sign_extension_in_reg, smulu64]> {
160166
let CombineAllMethodName = "tryCombineAllImpl";
161167
}
162168

llvm/lib/Target/AMDGPU/AMDGPUPostLegalizerCombiner.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ class AMDGPUPostLegalizerCombinerImpl : public Combiner {
8383
matchRcpSqrtToRsq(MachineInstr &MI,
8484
std::function<void(MachineIRBuilder &)> &MatchInfo) const;
8585

86+
bool matchFDivSqrtToRsqF16(MachineInstr &MI) const;
87+
void applyFDivSqrtToRsqF16(MachineInstr &MI, const Register &X) const;
88+
8689
// FIXME: Should be able to have 2 separate matchdatas rather than custom
8790
// struct boilerplate.
8891
struct CvtF32UByteMatchInfo {
@@ -334,6 +337,26 @@ bool AMDGPUPostLegalizerCombinerImpl::matchRcpSqrtToRsq(
334337
return false;
335338
}
336339

340+
bool AMDGPUPostLegalizerCombinerImpl::matchFDivSqrtToRsqF16(
341+
MachineInstr &MI) const {
342+
Register Sqrt = MI.getOperand(2).getReg();
343+
return MRI.hasOneNonDBGUse(Sqrt);
344+
}
345+
346+
void AMDGPUPostLegalizerCombinerImpl::applyFDivSqrtToRsqF16(
347+
MachineInstr &MI, const Register &X) const {
348+
Register Dst = MI.getOperand(0).getReg();
349+
Register Y = MI.getOperand(1).getReg();
350+
LLT DstTy = MRI.getType(Dst);
351+
uint32_t Flags = MI.getFlags();
352+
Register RSQ = B.buildIntrinsic(Intrinsic::amdgcn_rsq, {DstTy})
353+
.addUse(X)
354+
.setMIFlags(Flags)
355+
.getReg(0);
356+
B.buildFMul(Dst, RSQ, Y, Flags);
357+
MI.eraseFromParent();
358+
}
359+
337360
bool AMDGPUPostLegalizerCombinerImpl::matchCvtF32UByteN(
338361
MachineInstr &MI, CvtF32UByteMatchInfo &MatchInfo) const {
339362
Register SrcReg = MI.getOperand(1).getReg();

0 commit comments

Comments
 (0)