Skip to content

Commit 767999f

Browse files
committed
[AMDGPU][GlobalISel] Support mad/fma_mix selection
Adds support for selecting the following instructions using GlobalISel:
- v_mad_mix/v_fma_mix
- v_mad_mixhi/v_fma_mixhi
- v_mad_mixlo/v_fma_mixlo

To select those instructions properly, some additional changes were needed, which impacted other tests as well.

Reviewed By: arsenm

Differential Revision: https://reviews.llvm.org/D134354
1 parent d1f90b6 commit 767999f

11 files changed

+3375
-1717
lines changed

llvm/lib/Target/AMDGPU/AMDGPUGISel.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def gi_smrd_buffer_sgpr_imm :
153153
GIComplexOperandMatcher<s64, "selectSMRDBufferSgprImm">,
154154
GIComplexPatternEquiv<SMRDBufferSgprImm>;
155155

156+
// GlobalISel counterpart of the VOP3PMadMixMods complex pattern. Operands are
// matched by the selector's "selectVOP3PMadMixMods" routine.
def gi_vop3_mad_mix_mods :
157+
GIComplexOperandMatcher<s64, "selectVOP3PMadMixMods">,
158+
GIComplexPatternEquiv<VOP3PMadMixMods>;
159+
156160
// Separate load nodes are defined to glue m0 initialization in
157161
// SelectionDAG. The GISel selector can just insert m0 initialization
158162
// directly before selecting a glue-less load, so hide this

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,60 @@ bool AMDGPUInstructionSelector::selectG_EXTRACT(MachineInstr &I) const {
522522
return true;
523523
}
524524

525+
/// Try to select a G_FMA / G_FMAD with a 32-bit scalar result as
/// V_FMA_MIX_F32 / V_MAD_MIX_F32.
///
/// The mix form is only used when the subtarget provides the matching
/// instruction and selectVOP3PMadMixModsImpl reports a match for at least one
/// of the three source operands.
///
/// \param I the G_FMA or G_FMAD instruction to select.
/// \returns true if \p I was replaced by a mix instruction and erased. Returns
/// false with the block left as it was found, so that the caller can fall
/// back to another selection strategy for \p I.
bool AMDGPUInstructionSelector::selectG_FMA_FMAD(MachineInstr &I) const {
  assert(I.getOpcode() == AMDGPU::G_FMA || I.getOpcode() == AMDGPU::G_FMAD);

  // Try to manually select MAD_MIX/FMA_MIX.
  Register Dst = I.getOperand(0).getReg();
  LLT ResultTy = MRI->getType(Dst);
  bool IsFMA = I.getOpcode() == AMDGPU::G_FMA;
  if (ResultTy != LLT::scalar(32) ||
      (IsFMA ? !Subtarget->hasFmaMixInsts() : !Subtarget->hasMadMixInsts()))
    return false;

  // Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand
  // using the conversion from f16.
  bool MatchedSrc0, MatchedSrc1, MatchedSrc2;
  auto [Src0, Src0Mods] =
      selectVOP3PMadMixModsImpl(I.getOperand(1), MatchedSrc0);
  auto [Src1, Src1Mods] =
      selectVOP3PMadMixModsImpl(I.getOperand(2), MatchedSrc1);
  auto [Src2, Src2Mods] =
      selectVOP3PMadMixModsImpl(I.getOperand(3), MatchedSrc2);

#ifndef NDEBUG
  const SIMachineFunctionInfo *MFI =
      I.getMF()->getInfo<SIMachineFunctionInfo>();
  AMDGPU::SIModeRegisterDefaults Mode = MFI->getMode();
  assert((IsFMA || !Mode.allFP32Denormals()) &&
         "fmad selected with denormals enabled");
#endif

  // TODO: We can select this with f32 denormals enabled if all the sources are
  // converted from f16 (in which case fmad isn't legal).
  if (!MatchedSrc0 && !MatchedSrc1 && !MatchedSrc2)
    return false;

  const unsigned OpC = IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32;
  // NOTE(review): the three trailing zero immediates are presumably clamp,
  // op_sel and op_sel_hi -- confirm against the instruction definition.
  MachineInstr *MixInst =
      BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(OpC), Dst)
          .addImm(Src0Mods)
          .addReg(Src0)
          .addImm(Src1Mods)
          .addReg(Src1)
          .addImm(Src2Mods)
          .addReg(Src2)
          .addImm(0)
          .addImm(0)
          .addImm(0);

  if (!constrainSelectedInstRegOperands(*MixInst, TII, TRI, RBI)) {
    // select() retries with selectImpl when this function fails, so the
    // half-selected instruction must not stay behind as a second definition
    // of Dst.
    MixInst->eraseFromParent();
    return false;
  }

  I.eraseFromParent();
  return true;
}
578+
525579
bool AMDGPUInstructionSelector::selectG_MERGE_VALUES(MachineInstr &MI) const {
526580
MachineBasicBlock *BB = MI.getParent();
527581
Register DstReg = MI.getOperand(0).getReg();
@@ -3228,6 +3282,11 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
32283282
return selectG_FABS(I);
32293283
case TargetOpcode::G_EXTRACT:
32303284
return selectG_EXTRACT(I);
3285+
case TargetOpcode::G_FMA:
3286+
case TargetOpcode::G_FMAD:
3287+
if (selectG_FMA_FMAD(I))
3288+
return true;
3289+
return selectImpl(I, *CoverageInfo);
32313290
case TargetOpcode::G_MERGE_VALUES:
32323291
case TargetOpcode::G_CONCAT_VECTORS:
32333292
return selectG_MERGE_VALUES(I);
@@ -4679,6 +4738,137 @@ AMDGPUInstructionSelector::selectSMRDBufferSgprImm(MachineOperand &Root) const {
46794738
[=](MachineInstrBuilder &MIB) { MIB.addImm(*EncodedOffset); }}};
46804739
}
46814740

4741+
// Variant of stripBitCast that returns the instruction instead of a
4742+
// MachineOperand.
4743+
static MachineInstr *stripBitCast(MachineInstr *MI, MachineRegisterInfo &MRI) {
4744+
if (MI->getOpcode() == AMDGPU::G_BITCAST)
4745+
return getDefIgnoringCopies(MI->getOperand(1).getReg(), MRI);
4746+
return MI;
4747+
}
4748+
4749+
// Figure out if this is really an extract of the high 16-bits of a dword,
4750+
// returns nullptr if it isn't.
4751+
static MachineInstr *isExtractHiElt(MachineInstr *Inst,
4752+
MachineRegisterInfo &MRI) {
4753+
Inst = stripBitCast(Inst, MRI);
4754+
4755+
if (Inst->getOpcode() != AMDGPU::G_TRUNC)
4756+
return nullptr;
4757+
4758+
MachineInstr *TruncOp =
4759+
getDefIgnoringCopies(Inst->getOperand(1).getReg(), MRI);
4760+
TruncOp = stripBitCast(TruncOp, MRI);
4761+
4762+
// G_LSHR x, (G_CONSTANT i32 16)
4763+
if (TruncOp->getOpcode() == AMDGPU::G_LSHR) {
4764+
auto SrlAmount = getIConstantVRegValWithLookThrough(
4765+
TruncOp->getOperand(2).getReg(), MRI);
4766+
if (SrlAmount && SrlAmount->Value.getZExtValue() == 16) {
4767+
MachineInstr *SrlOp =
4768+
getDefIgnoringCopies(TruncOp->getOperand(1).getReg(), MRI);
4769+
return stripBitCast(SrlOp, MRI);
4770+
}
4771+
}
4772+
4773+
// G_SHUFFLE_VECTOR x, y, shufflemask(1, 1|0)
4774+
// 1, 0 swaps the low/high 16 bits.
4775+
// 1, 1 sets the high 16 bits to be the same as the low 16.
4776+
// in any case, it selects the high elts.
4777+
if (TruncOp->getOpcode() == AMDGPU::G_SHUFFLE_VECTOR) {
4778+
assert(MRI.getType(TruncOp->getOperand(0).getReg()) ==
4779+
LLT::fixed_vector(2, 16));
4780+
4781+
ArrayRef<int> Mask = TruncOp->getOperand(3).getShuffleMask();
4782+
assert(Mask.size() == 2);
4783+
4784+
if (Mask[0] == 1 && Mask[1] <= 1) {
4785+
MachineInstr *LHS =
4786+
getDefIgnoringCopies(TruncOp->getOperand(1).getReg(), MRI);
4787+
return stripBitCast(LHS, MRI);
4788+
}
4789+
}
4790+
4791+
return nullptr;
4792+
}
4793+
4794+
std::pair<Register, unsigned>
4795+
AMDGPUInstructionSelector::selectVOP3PMadMixModsImpl(MachineOperand &Root,
4796+
bool &Matched) const {
4797+
Matched = false;
4798+
4799+
Register Src;
4800+
unsigned Mods;
4801+
std::tie(Src, Mods) = selectVOP3ModsImpl(Root);
4802+
4803+
MachineInstr *MI = getDefIgnoringCopies(Src, *MRI);
4804+
if (MI->getOpcode() == AMDGPU::G_FPEXT) {
4805+
MachineOperand *MO = &MI->getOperand(1);
4806+
Src = MO->getReg();
4807+
MI = getDefIgnoringCopies(Src, *MRI);
4808+
4809+
assert(MRI->getType(Src) == LLT::scalar(16));
4810+
4811+
// See through bitcasts.
4812+
// FIXME: Would be nice to use stripBitCast here.
4813+
if (MI->getOpcode() == AMDGPU::G_BITCAST) {
4814+
MO = &MI->getOperand(1);
4815+
Src = MO->getReg();
4816+
MI = getDefIgnoringCopies(Src, *MRI);
4817+
}
4818+
4819+
const auto CheckAbsNeg = [&]() {
4820+
// Be careful about folding modifiers if we already have an abs. fneg is
4821+
// applied last, so we don't want to apply an earlier fneg.
4822+
if ((Mods & SISrcMods::ABS) == 0) {
4823+
unsigned ModsTmp;
4824+
std::tie(Src, ModsTmp) = selectVOP3ModsImpl(*MO);
4825+
MI = getDefIgnoringCopies(Src, *MRI);
4826+
4827+
if ((ModsTmp & SISrcMods::NEG) != 0)
4828+
Mods ^= SISrcMods::NEG;
4829+
4830+
if ((ModsTmp & SISrcMods::ABS) != 0)
4831+
Mods |= SISrcMods::ABS;
4832+
}
4833+
};
4834+
4835+
CheckAbsNeg();
4836+
4837+
// op_sel/op_sel_hi decide the source type and source.
4838+
// If the source's op_sel_hi is set, it indicates to do a conversion from
4839+
// fp16. If the sources's op_sel is set, it picks the high half of the
4840+
// source register.
4841+
4842+
Mods |= SISrcMods::OP_SEL_1;
4843+
4844+
if (MachineInstr *ExtractHiEltMI = isExtractHiElt(MI, *MRI)) {
4845+
Mods |= SISrcMods::OP_SEL_0;
4846+
MI = ExtractHiEltMI;
4847+
MO = &MI->getOperand(0);
4848+
Src = MO->getReg();
4849+
4850+
CheckAbsNeg();
4851+
}
4852+
4853+
Matched = true;
4854+
}
4855+
4856+
return {Src, Mods};
4857+
}
4858+
4859+
/// Complex-operand matcher for mad_mix/fma_mix sources.
///
/// Delegates to selectVOP3PMadMixModsImpl and always succeeds, producing two
/// renderers: one adding the source register and one adding the modifier
/// immediate.
InstructionSelector::ComplexRendererFns
AMDGPUInstructionSelector::selectVOP3PMadMixMods(MachineOperand &Root) const {
  // Required by the Impl's signature; the result is not consulted here.
  bool Matched;
  const std::pair<Register, unsigned> SrcAndMods =
      selectVOP3PMadMixModsImpl(Root, Matched);
  const Register Src = SrcAndMods.first;
  const unsigned Mods = SrcAndMods.second;

  return {{
      [=](MachineInstrBuilder &MIB) { MIB.addReg(Src); },
      [=](MachineInstrBuilder &MIB) { MIB.addImm(Mods); } // src_mods
  }};
}
4871+
46824872
void AMDGPUInstructionSelector::renderTruncImm32(MachineInstrBuilder &MIB,
46834873
const MachineInstr &MI,
46844874
int OpIdx) const {

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
9797
bool selectG_UADDO_USUBO_UADDE_USUBE(MachineInstr &I) const;
9898
bool selectG_AMDGPU_MAD_64_32(MachineInstr &I) const;
9999
bool selectG_EXTRACT(MachineInstr &I) const;
100+
bool selectG_FMA_FMAD(MachineInstr &I) const;
100101
bool selectG_MERGE_VALUES(MachineInstr &I) const;
101102
bool selectG_UNMERGE_VALUES(MachineInstr &I) const;
102103
bool selectG_BUILD_VECTOR(MachineInstr &I) const;
@@ -293,6 +294,10 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
293294
ComplexRendererFns selectSMRDBufferImm32(MachineOperand &Root) const;
294295
ComplexRendererFns selectSMRDBufferSgprImm(MachineOperand &Root) const;
295296

297+
std::pair<Register, unsigned> selectVOP3PMadMixModsImpl(MachineOperand &Root,
298+
bool &Matched) const;
299+
ComplexRendererFns selectVOP3PMadMixMods(MachineOperand &Root) const;
300+
296301
void renderTruncImm32(MachineInstrBuilder &MIB, const MachineInstr &MI,
297302
int OpIdx = -1) const;
298303

llvm/lib/Target/AMDGPU/VOP3PInstructions.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
168168
$src1_modifiers, $src1,
169169
$src2_modifiers, $src2,
170170
DSTCLAMP.NONE,
171-
$elt0))
171+
VGPR_32:$elt0))
172172
>;
173173

174174
def : GCNPat <
@@ -181,7 +181,7 @@ multiclass MadFmaMixPats<SDPatternOperator fma_like,
181181
$src1_modifiers, $src1,
182182
$src2_modifiers, $src2,
183183
DSTCLAMP.ENABLE,
184-
$elt0))
184+
VGPR_32:$elt0))
185185
>;
186186

187187
def : GCNPat <

0 commit comments

Comments
 (0)