Skip to content

Commit beceb75

Browse files
committed
[AMDGPU] Codegen for pseudo scalar transcendental instructions
1 parent 12b824f commit beceb75

9 files changed

+940
-30
lines changed

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3781,14 +3781,20 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
37813781
return getDefaultMappingSOP(MI);
37823782
return getDefaultMappingVOP(MI);
37833783
}
3784+
case AMDGPU::G_FSQRT:
3785+
case AMDGPU::G_FEXP2:
3786+
case AMDGPU::G_FLOG2: {
3787+
unsigned Size = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
3788+
if (Subtarget.hasPseudoScalarTrans() && (Size == 16 || Size == 32) &&
3789+
isSALUMapping(MI))
3790+
return getDefaultMappingSOP(MI);
3791+
return getDefaultMappingVOP(MI);
3792+
}
37843793
case AMDGPU::G_SADDSAT: // FIXME: Could lower sat ops for SALU
37853794
case AMDGPU::G_SSUBSAT:
37863795
case AMDGPU::G_UADDSAT:
37873796
case AMDGPU::G_USUBSAT:
37883797
case AMDGPU::G_FMAD:
3789-
case AMDGPU::G_FSQRT:
3790-
case AMDGPU::G_FEXP2:
3791-
case AMDGPU::G_FLOG2:
37923798
case AMDGPU::G_FLDEXP:
37933799
case AMDGPU::G_FMINNUM_IEEE:
37943800
case AMDGPU::G_FMAXNUM_IEEE:
@@ -4253,12 +4259,7 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
42534259
case Intrinsic::amdgcn_sin:
42544260
case Intrinsic::amdgcn_cos:
42554261
case Intrinsic::amdgcn_log_clamp:
4256-
case Intrinsic::amdgcn_log:
4257-
case Intrinsic::amdgcn_exp2:
4258-
case Intrinsic::amdgcn_rcp:
42594262
case Intrinsic::amdgcn_rcp_legacy:
4260-
case Intrinsic::amdgcn_sqrt:
4261-
case Intrinsic::amdgcn_rsq:
42624263
case Intrinsic::amdgcn_rsq_legacy:
42634264
case Intrinsic::amdgcn_rsq_clamp:
42644265
case Intrinsic::amdgcn_fmul_legacy:
@@ -4315,6 +4316,17 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
43154316
case Intrinsic::amdgcn_wmma_i32_16x16x16_iu4:
43164317
case Intrinsic::amdgcn_wmma_i32_16x16x16_iu8:
43174318
return getDefaultMappingVOP(MI);
4319+
case Intrinsic::amdgcn_log:
4320+
case Intrinsic::amdgcn_exp2:
4321+
case Intrinsic::amdgcn_rcp:
4322+
case Intrinsic::amdgcn_rsq:
4323+
case Intrinsic::amdgcn_sqrt: {
4324+
unsigned Size = MRI.getType(MI.getOperand(0).getReg()).getSizeInBits();
4325+
if (Subtarget.hasPseudoScalarTrans() && (Size == 16 || Size == 32) &&
4326+
isSALUMapping(MI))
4327+
return getDefaultMappingSOP(MI);
4328+
return getDefaultMappingVOP(MI);
4329+
}
43184330
case Intrinsic::amdgcn_sbfe:
43194331
case Intrinsic::amdgcn_ubfe:
43204332
if (isSALUMapping(MI))

llvm/lib/Target/AMDGPU/GCNSubtarget.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
11611161

11621162
bool hasVGPRSingleUseHintInsts() const { return HasVGPRSingleUseHintInsts; }
11631163

1164+
/// \returns the value of the HasPseudoScalarTrans member.
bool hasPseudoScalarTrans() const { return HasPseudoScalarTrans; }
1165+
11641166
/// Return the maximum number of waves per SIMD for kernels using \p SGPRs
11651167
/// SGPRs
11661168
unsigned getOccupancyWithNumSGPRs(unsigned SGPRs) const;

llvm/lib/Target/AMDGPU/SOPInstructions.td

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -675,19 +675,8 @@ let SubtargetPredicate = isGFX12Plus in {
675675

676676
} // End SubtargetPredicate = isGFX12Plus
677677

678-
def SelectPat : PatFrag <
679-
(ops node:$src1, node:$src2),
680-
(select SCC, $src1, $src2),
681-
[{ return !N->isDivergent(); }]
682-
>;
683-
684678
let Uses = [SCC] in {
685-
let AddedComplexity = 20 in {
686-
def S_CSELECT_B32 : SOP2_32 <"s_cselect_b32",
687-
[(set i32:$sdst, (SelectPat i32:$src0, i32:$src1))]
688-
>;
689-
}
690-
679+
def S_CSELECT_B32 : SOP2_32 <"s_cselect_b32">;
691680
def S_CSELECT_B64 : SOP2_64 <"s_cselect_b64">;
692681
} // End Uses = [SCC]
693682

@@ -1808,6 +1797,27 @@ def : GetFPModePat<fpmode_mask_gfx6plus>;
18081797
// SOP2 Patterns
18091798
//===----------------------------------------------------------------------===//
18101799

1800+
// Pattern fragment matching (select SCC, $src0, $src1). The predicate accepts
// the match only when the matched node N is not divergent.
def UniformSelect : PatFrag<
1801+
(ops node:$src0, node:$src1),
1802+
(select SCC, $src0, $src1),
1803+
[{ return !N->isDivergent(); }]
1804+
>;
1805+
1806+
// Selection patterns mapping a UniformSelect to S_CSELECT_B32. Both patterns
// below are given AddedComplexity = 20.
// NOTE(review): presumably the added complexity makes these win over other
// select patterns -- confirm against the remaining select patterns.
let AddedComplexity = 20 in {
1807+
// i32 uniform select; no extra predicate is attached to this pattern here.
def : GCNPat<
1808+
(i32 (UniformSelect i32:$src0, i32:$src1)),
1809+
(S_CSELECT_B32 SSrc_b32:$src0, SSrc_b32:$src1)
1810+
>;
1811+
1812+
// TODO: The predicate should not be necessary, but enabling this pattern for
1813+
// all subtargets generates worse code in some cases.
1814+
// f32 uniform select; only enabled under the HasPseudoScalarTrans predicate.
let OtherPredicates = [HasPseudoScalarTrans] in
1815+
def : GCNPat<
1816+
(f32 (UniformSelect f32:$src0, f32:$src1)),
1817+
(S_CSELECT_B32 SSrc_b32:$src0, SSrc_b32:$src1)
1818+
>;
1819+
}
1820+
18111821
// V_ADD_I32_e32/S_ADD_U32 produces carry in VCC/SCC. For the vector
18121822
// case, the sgpr-copies pass will fix this to use the vector version.
18131823
def : GCNPat <

llvm/lib/Target/AMDGPU/VOP3Instructions.td

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -855,18 +855,34 @@ def VOP_Pseudo_Scalar_F16 : VOP_Pseudo_Scalar<SReg_32_XEXEC, SSrc_f16, f32, f16>
855855

856856
let SubtargetPredicate = HasPseudoScalarTrans, TRANS = 1,
857857
isReMaterializable = 1, SchedRW = [WritePseudoScalarTrans] in {
858-
defm V_S_EXP_F32 : VOP3PseudoScalarInst<"v_s_exp_f32", VOP_Pseudo_Scalar_F32>;
858+
defm V_S_EXP_F32 : VOP3PseudoScalarInst<"v_s_exp_f32", VOP_Pseudo_Scalar_F32, AMDGPUexp>;
859859
defm V_S_EXP_F16 : VOP3PseudoScalarInst<"v_s_exp_f16", VOP_Pseudo_Scalar_F16>;
860-
defm V_S_LOG_F32 : VOP3PseudoScalarInst<"v_s_log_f32", VOP_Pseudo_Scalar_F32>;
860+
defm V_S_LOG_F32 : VOP3PseudoScalarInst<"v_s_log_f32", VOP_Pseudo_Scalar_F32, AMDGPUlog>;
861861
defm V_S_LOG_F16 : VOP3PseudoScalarInst<"v_s_log_f16", VOP_Pseudo_Scalar_F16>;
862-
defm V_S_RCP_F32 : VOP3PseudoScalarInst<"v_s_rcp_f32", VOP_Pseudo_Scalar_F32>;
862+
defm V_S_RCP_F32 : VOP3PseudoScalarInst<"v_s_rcp_f32", VOP_Pseudo_Scalar_F32, AMDGPUrcp>;
863863
defm V_S_RCP_F16 : VOP3PseudoScalarInst<"v_s_rcp_f16", VOP_Pseudo_Scalar_F16>;
864-
defm V_S_RSQ_F32 : VOP3PseudoScalarInst<"v_s_rsq_f32", VOP_Pseudo_Scalar_F32>;
864+
defm V_S_RSQ_F32 : VOP3PseudoScalarInst<"v_s_rsq_f32", VOP_Pseudo_Scalar_F32, AMDGPUrsq>;
865865
defm V_S_RSQ_F16 : VOP3PseudoScalarInst<"v_s_rsq_f16", VOP_Pseudo_Scalar_F16>;
866-
defm V_S_SQRT_F32 : VOP3PseudoScalarInst<"v_s_sqrt_f32", VOP_Pseudo_Scalar_F32>;
866+
defm V_S_SQRT_F32 : VOP3PseudoScalarInst<"v_s_sqrt_f32", VOP_Pseudo_Scalar_F32, any_amdgcn_sqrt>;
867867
defm V_S_SQRT_F16 : VOP3PseudoScalarInst<"v_s_sqrt_f16", VOP_Pseudo_Scalar_F16>;
868868
}
869869

870+
// Selection pattern for an f16 unary operation `node` onto the pseudo-scalar
// instruction `inst`.
//  - Source pattern: `node` wrapped in UniformUnaryFrag, applied to an f16
//    operand matched through VOP3Mods0 (yielding $src0, $src0_modifiers,
//    $clamp and $omod).
//  - Result pattern: `inst` is emitted with an f32-typed result, which is then
//    copied to the SReg_32_XEXEC register class to produce the f16 value.
// NOTE(review): UniformUnaryFrag is defined elsewhere; presumably it restricts
// the match to non-divergent nodes -- confirm.
class PseudoScalarPatF16<SDPatternOperator node, VOP3_Pseudo inst> : GCNPat <
871+
(f16 (UniformUnaryFrag<node> (f16 (VOP3Mods0 f16:$src0, i32:$src0_modifiers,
872+
i1:$clamp, i32:$omod)))),
873+
(f16 (COPY_TO_REGCLASS (f32 (inst i32:$src0_modifiers, f16:$src0, i1:$clamp,
874+
i32:$omod)),
875+
SReg_32_XEXEC))
876+
>;
877+
878+
// Instantiate PseudoScalarPatF16 for the exp, log, rcp, rsq and sqrt f16
// pseudo-scalar instructions. All five patterns are predicated on
// HasPseudoScalarTrans.
let SubtargetPredicate = HasPseudoScalarTrans in {
879+
def : PseudoScalarPatF16<AMDGPUexpf16, V_S_EXP_F16_e64>;
880+
def : PseudoScalarPatF16<AMDGPUlogf16, V_S_LOG_F16_e64>;
881+
def : PseudoScalarPatF16<AMDGPUrcp, V_S_RCP_F16_e64>;
882+
def : PseudoScalarPatF16<AMDGPUrsq, V_S_RSQ_F16_e64>;
883+
def : PseudoScalarPatF16<any_amdgcn_sqrt, V_S_SQRT_F16_e64>;
884+
}
885+
870886
//===----------------------------------------------------------------------===//
871887
// Integer Clamp Patterns
872888
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/VOPInstructions.td

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,8 +1303,17 @@ multiclass VOP3Inst<string OpName, VOPProfile P, SDPatternOperator node = null_f
13031303
} // end SubtargetPredicate = isGFX11Plus
13041304
}
13051305

1306-
multiclass VOP3PseudoScalarInst<string OpName, VOPProfile P> {
1307-
def _e64 : VOP3_Pseudo<OpName, P>;
1306+
// Helper that picks the pattern operator to use for `Op`:
//  - if `Op` is an SDNode or a PatFrags record, `ret` is UniformUnaryFrag<Op>;
//  - otherwise `ret` is `Op` itself, unchanged.
// NOTE(review): presumably the pass-through case exists for operators such as
// null_frag that cannot be wrapped -- confirm.
class UniformUnaryFragOrOp<SDPatternOperator Op> {
1307+
SDPatternOperator ret = !if(!or(!isa<SDNode>(Op), !isa<PatFrags>(Op)),
1308+
UniformUnaryFrag<Op>, Op);
1309+
}
1310+
1311+
// Defines the `_e64` VOP3 pseudo for a pseudo-scalar instruction.
//   OpName - name string passed to VOP3_Pseudo.
//   P      - VOP profile; supplies DstVT and Src0VT for the pattern.
//   node   - operator to select from; defaults to null_frag.
// The attached pattern sets $vdst (of type P.DstVT) from
// UniformUnaryFragOrOp<node>.ret applied to a P.Src0VT source matched through
// VOP3Mods0 ($src0, $src0_modifiers, $clamp, $omod).
multiclass VOP3PseudoScalarInst<string OpName, VOPProfile P,
1312+
SDPatternOperator node = null_frag> {
1313+
def _e64 : VOP3_Pseudo<OpName, P, [(set P.DstVT:$vdst,
1314+
(UniformUnaryFragOrOp<node>.ret
1315+
(P.Src0VT (VOP3Mods0 P.Src0VT:$src0, i32:$src0_modifiers, i1:$clamp,
1316+
i32:$omod))))]>;
13081317
}
13091318

13101319
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)