@@ -522,6 +522,60 @@ bool AMDGPUInstructionSelector::selectG_EXTRACT(MachineInstr &I) const {
522
522
return true ;
523
523
}
524
524
525
+ bool AMDGPUInstructionSelector::selectG_FMA_FMAD (MachineInstr &I) const {
526
+ assert (I.getOpcode () == AMDGPU::G_FMA || I.getOpcode () == AMDGPU::G_FMAD);
527
+
528
+ // Try to manually select MAD_MIX/FMA_MIX.
529
+ Register Dst = I.getOperand (0 ).getReg ();
530
+ LLT ResultTy = MRI->getType (Dst);
531
+ bool IsFMA = I.getOpcode () == AMDGPU::G_FMA;
532
+ if (ResultTy != LLT::scalar (32 ) ||
533
+ (IsFMA ? !Subtarget->hasFmaMixInsts () : !Subtarget->hasMadMixInsts ()))
534
+ return false ;
535
+
536
+ // Avoid using v_mad_mix_f32/v_fma_mix_f32 unless there is actually an operand
537
+ // using the conversion from f16.
538
+ bool MatchedSrc0, MatchedSrc1, MatchedSrc2;
539
+ auto [Src0, Src0Mods] =
540
+ selectVOP3PMadMixModsImpl (I.getOperand (1 ), MatchedSrc0);
541
+ auto [Src1, Src1Mods] =
542
+ selectVOP3PMadMixModsImpl (I.getOperand (2 ), MatchedSrc1);
543
+ auto [Src2, Src2Mods] =
544
+ selectVOP3PMadMixModsImpl (I.getOperand (3 ), MatchedSrc2);
545
+
546
+ #ifndef NDEBUG
547
+ const SIMachineFunctionInfo *MFI =
548
+ I.getMF ()->getInfo <SIMachineFunctionInfo>();
549
+ AMDGPU::SIModeRegisterDefaults Mode = MFI->getMode ();
550
+ assert ((IsFMA || !Mode.allFP32Denormals ()) &&
551
+ " fmad selected with denormals enabled" );
552
+ #endif
553
+
554
+ // TODO: We can select this with f32 denormals enabled if all the sources are
555
+ // converted from f16 (in which case fmad isn't legal).
556
+ if (!MatchedSrc0 && !MatchedSrc1 && !MatchedSrc2)
557
+ return false ;
558
+
559
+ const unsigned OpC = IsFMA ? AMDGPU::V_FMA_MIX_F32 : AMDGPU::V_MAD_MIX_F32;
560
+ MachineInstr *MixInst =
561
+ BuildMI (*I.getParent (), I, I.getDebugLoc (), TII.get (OpC), Dst)
562
+ .addImm (Src0Mods)
563
+ .addReg (Src0)
564
+ .addImm (Src1Mods)
565
+ .addReg (Src1)
566
+ .addImm (Src2Mods)
567
+ .addReg (Src2)
568
+ .addImm (0 )
569
+ .addImm (0 )
570
+ .addImm (0 );
571
+
572
+ if (!constrainSelectedInstRegOperands (*MixInst, TII, TRI, RBI))
573
+ return false ;
574
+
575
+ I.eraseFromParent ();
576
+ return true ;
577
+ }
578
+
525
579
bool AMDGPUInstructionSelector::selectG_MERGE_VALUES (MachineInstr &MI) const {
526
580
MachineBasicBlock *BB = MI.getParent ();
527
581
Register DstReg = MI.getOperand (0 ).getReg ();
@@ -3228,6 +3282,11 @@ bool AMDGPUInstructionSelector::select(MachineInstr &I) {
3228
3282
return selectG_FABS (I);
3229
3283
case TargetOpcode::G_EXTRACT:
3230
3284
return selectG_EXTRACT (I);
3285
+ case TargetOpcode::G_FMA:
3286
+ case TargetOpcode::G_FMAD:
3287
+ if (selectG_FMA_FMAD (I))
3288
+ return true ;
3289
+ return selectImpl (I, *CoverageInfo);
3231
3290
case TargetOpcode::G_MERGE_VALUES:
3232
3291
case TargetOpcode::G_CONCAT_VECTORS:
3233
3292
return selectG_MERGE_VALUES (I);
@@ -4679,6 +4738,137 @@ AMDGPUInstructionSelector::selectSMRDBufferSgprImm(MachineOperand &Root) const {
4679
4738
[=](MachineInstrBuilder &MIB) { MIB.addImm (*EncodedOffset); }}};
4680
4739
}
4681
4740
4741
+ // Variant of stripBitCast that returns the instruction instead of a
4742
+ // MachineOperand.
4743
+ static MachineInstr *stripBitCast (MachineInstr *MI, MachineRegisterInfo &MRI) {
4744
+ if (MI->getOpcode () == AMDGPU::G_BITCAST)
4745
+ return getDefIgnoringCopies (MI->getOperand (1 ).getReg (), MRI);
4746
+ return MI;
4747
+ }
4748
+
4749
+ // Figure out if this is really an extract of the high 16-bits of a dword,
4750
+ // returns nullptr if it isn't.
4751
+ static MachineInstr *isExtractHiElt (MachineInstr *Inst,
4752
+ MachineRegisterInfo &MRI) {
4753
+ Inst = stripBitCast (Inst, MRI);
4754
+
4755
+ if (Inst->getOpcode () != AMDGPU::G_TRUNC)
4756
+ return nullptr ;
4757
+
4758
+ MachineInstr *TruncOp =
4759
+ getDefIgnoringCopies (Inst->getOperand (1 ).getReg (), MRI);
4760
+ TruncOp = stripBitCast (TruncOp, MRI);
4761
+
4762
+ // G_LSHR x, (G_CONSTANT i32 16)
4763
+ if (TruncOp->getOpcode () == AMDGPU::G_LSHR) {
4764
+ auto SrlAmount = getIConstantVRegValWithLookThrough (
4765
+ TruncOp->getOperand (2 ).getReg (), MRI);
4766
+ if (SrlAmount && SrlAmount->Value .getZExtValue () == 16 ) {
4767
+ MachineInstr *SrlOp =
4768
+ getDefIgnoringCopies (TruncOp->getOperand (1 ).getReg (), MRI);
4769
+ return stripBitCast (SrlOp, MRI);
4770
+ }
4771
+ }
4772
+
4773
+ // G_SHUFFLE_VECTOR x, y, shufflemask(1, 1|0)
4774
+ // 1, 0 swaps the low/high 16 bits.
4775
+ // 1, 1 sets the high 16 bits to be the same as the low 16.
4776
+ // in any case, it selects the high elts.
4777
+ if (TruncOp->getOpcode () == AMDGPU::G_SHUFFLE_VECTOR) {
4778
+ assert (MRI.getType (TruncOp->getOperand (0 ).getReg ()) ==
4779
+ LLT::fixed_vector (2 , 16 ));
4780
+
4781
+ ArrayRef<int > Mask = TruncOp->getOperand (3 ).getShuffleMask ();
4782
+ assert (Mask.size () == 2 );
4783
+
4784
+ if (Mask[0 ] == 1 && Mask[1 ] <= 1 ) {
4785
+ MachineInstr *LHS =
4786
+ getDefIgnoringCopies (TruncOp->getOperand (1 ).getReg (), MRI);
4787
+ return stripBitCast (LHS, MRI);
4788
+ }
4789
+ }
4790
+
4791
+ return nullptr ;
4792
+ }
4793
+
4794
// Match a mad/fma-mix source operand: the source register plus the SISrcMods
// modifier bits (neg/abs and the mix op_sel/op_sel_hi bits). Matched is set
// to true only when a G_FPEXT from f16 was folded into the operand; callers
// use it to decide whether the mix instruction is worthwhile.
std::pair<Register, unsigned>
AMDGPUInstructionSelector::selectVOP3PMadMixModsImpl(MachineOperand &Root,
                                                     bool &Matched) const {
  Matched = false;

  // Start with the plain VOP3 neg/abs modifiers on the root value.
  Register Src;
  unsigned Mods;
  std::tie(Src, Mods) = selectVOP3ModsImpl(Root);

  MachineInstr *MI = getDefIgnoringCopies(Src, *MRI);
  if (MI->getOpcode() == AMDGPU::G_FPEXT) {
    // Fold the f16 -> f32 extension: from here on we operate on the f16
    // source of the extend. MO tracks the operand the modifier-matcher
    // should look through next.
    MachineOperand *MO = &MI->getOperand(1);
    Src = MO->getReg();
    MI = getDefIgnoringCopies(Src, *MRI);

    assert(MRI->getType(Src) == LLT::scalar(16));

    // See through bitcasts.
    // FIXME: Would be nice to use stripBitCast here.
    if (MI->getOpcode() == AMDGPU::G_BITCAST) {
      MO = &MI->getOperand(1);
      Src = MO->getReg();
      MI = getDefIgnoringCopies(Src, *MRI);
    }

    // Fold any neg/abs found below the extend into Mods. Mutates Src, Mods
    // and MI by reference, and reads whatever MO currently points at, so
    // call order relative to the MO updates below matters.
    const auto CheckAbsNeg = [&]() {
      // Be careful about folding modifiers if we already have an abs. fneg is
      // applied last, so we don't want to apply an earlier fneg.
      if ((Mods & SISrcMods::ABS) == 0) {
        unsigned ModsTmp;
        std::tie(Src, ModsTmp) = selectVOP3ModsImpl(*MO);
        MI = getDefIgnoringCopies(Src, *MRI);

        // XOR: an inner fneg cancels (or introduces) an outer one.
        if ((ModsTmp & SISrcMods::NEG) != 0)
          Mods ^= SISrcMods::NEG;

        if ((ModsTmp & SISrcMods::ABS) != 0)
          Mods |= SISrcMods::ABS;
      }
    };

    CheckAbsNeg();

    // op_sel/op_sel_hi decide the source type and source.
    // If the source's op_sel_hi is set, it indicates to do a conversion from
    // fp16. If the sources's op_sel is set, it picks the high half of the
    // source register.

    Mods |= SISrcMods::OP_SEL_1;

    // If the f16 value is really the high half of a 32-bit register, use
    // that register directly and set op_sel to pick its high half.
    if (MachineInstr *ExtractHiEltMI = isExtractHiElt(MI, *MRI)) {
      Mods |= SISrcMods::OP_SEL_0;
      MI = ExtractHiEltMI;
      // Re-point MO at the extracted instruction's def and re-run the
      // modifier fold on it.
      MO = &MI->getOperand(0);
      Src = MO->getReg();

      CheckAbsNeg();
    }

    Matched = true;
  }

  return {Src, Mods};
}
4859
+ InstructionSelector::ComplexRendererFns
4860
+ AMDGPUInstructionSelector::selectVOP3PMadMixMods (MachineOperand &Root) const {
4861
+ Register Src;
4862
+ unsigned Mods;
4863
+ bool Matched;
4864
+ std::tie (Src, Mods) = selectVOP3PMadMixModsImpl (Root, Matched);
4865
+
4866
+ return {{
4867
+ [=](MachineInstrBuilder &MIB) { MIB.addReg (Src); },
4868
+ [=](MachineInstrBuilder &MIB) { MIB.addImm (Mods); } // src_mods
4869
+ }};
4870
+ }
4871
+
4682
4872
void AMDGPUInstructionSelector::renderTruncImm32 (MachineInstrBuilder &MIB,
4683
4873
const MachineInstr &MI,
4684
4874
int OpIdx) const {
0 commit comments