Skip to content

Commit 7c2ebe5

Browse files
authored
AMDGPU: Restrict src0 to VGPRs only for certain cvt scale opcodes. (llvm#127464)
The Src0 operand width higher that 32-bits of cvt_scale opcodes operating on FP6/BF6/FP4 need to be restricted to take only VGPRs.
1 parent 776cdda commit 7c2ebe5

File tree

6 files changed

+1664
-16
lines changed

6 files changed

+1664
-16
lines changed

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1803,19 +1803,16 @@ class getVOP3SrcForVT<ValueType VT, bit IsTrue16 = 0> {
18031803
1 : VSrc_b32);
18041804
}
18051805

1806-
// Returns the vreg register class to use for sources of VOP3 instructions for the
1807-
// given VT.
1808-
class getVOP3VRegSrcForVT<ValueType VT, bit IsTrue16 = 0, bit IsFake16 = 0> {
1809-
RegisterOperand ret =
1810-
!cond(!eq(VT.Size, 128) : RegisterOperand<VReg_128>,
1811-
!eq(VT.Size, 96) : RegisterOperand<VReg_96>,
1812-
!eq(VT.Size, 64) : RegisterOperand<VReg_64>,
1813-
!eq(VT.Size, 48) : RegisterOperand<VReg_64>,
1814-
!eq(VT.Size, 16) : !if(IsTrue16,
1815-
!if(IsFake16, RegisterOperand<VGPR_32>,
1816-
RegisterOperand<VGPR_16>),
1817-
RegisterOperand<VGPR_32>),
1818-
1 : RegisterOperand<VGPR_32>);
1806+
// VGPR only VOP3 src with 9 bit encoding
1807+
class getVOP3VRegSrcForVT<ValueType VT> {
1808+
RegisterOperand ret = !cond(!eq(VT.Size, 1024) : VRegSrc_1024,
1809+
!eq(VT.Size, 512) : VRegSrc_512,
1810+
!eq(VT.Size, 256) : VRegSrc_256,
1811+
!eq(VT.Size, 192) : VRegSrc_192,
1812+
!eq(VT.Size, 128) : VRegSrc_128,
1813+
!eq(VT.Size, 96) : VRegSrc_96,
1814+
!eq(VT.Size, 64) : VRegSrc_64,
1815+
1 : VRegSrc_32);
18191816
}
18201817

18211818
// Src2 of VOP3 DPP instructions cannot be a literal
@@ -2859,6 +2856,7 @@ def VOP_V2I16_F32_F32_F32 : VOPProfile<[v2i16, f32, f32, f32]>;
28592856
def VOP_V2I16_V2F16_F32 : VOPProfile<[v2i16, v2f16, f32, untyped]>;
28602857
def VOP_V2I16_V2BF16_F32 : VOPProfile<[v2i16, v2bf16, f32, untyped]>;
28612858
def VOP_I32_F32_F32_F32 : VOPProfile<[i32, f32, f32, f32]>;
2859+
def VOP_I32_V2F32_I32_F32 : VOPProfile<[i32, v2f32, i32, f32]>;
28622860
def VOP_I32_V2F16_F32_F32 : VOPProfile<[i32, v2f16, f32, f32]>;
28632861
def VOP_I32_V2BF16_F32_F32: VOPProfile<[i32, v2bf16, f32, f32]>;
28642862
def VOP_BF16_F32_I32 : VOPProfile<[bf16, f32, i32, untyped]>;

llvm/lib/Target/AMDGPU/VOP2Instructions.td

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,12 +418,27 @@ def VOP_MADMK_F16_fake16 : VOP_MADMK <f16> {
418418
}
419419
def VOP_MADMK_F32 : VOP_MADMK <f32>;
420420

421+
// Returns the vreg register class to use for sources of VOP3 instructions for the
422+
// given VT.
423+
class getVOP3VRegForVT<ValueType VT, bit IsTrue16 = 0, bit IsFake16 = 0> {
424+
RegisterOperand ret =
425+
!cond(!eq(VT.Size, 128) : RegisterOperand<VReg_128>,
426+
!eq(VT.Size, 96) : RegisterOperand<VReg_96>,
427+
!eq(VT.Size, 64) : RegisterOperand<VReg_64>,
428+
!eq(VT.Size, 48) : RegisterOperand<VReg_64>,
429+
!eq(VT.Size, 16) : !if(IsTrue16,
430+
!if(IsFake16, RegisterOperand<VGPR_32>,
431+
RegisterOperand<VGPR_16>),
432+
RegisterOperand<VGPR_32>),
433+
1 : RegisterOperand<VGPR_32>);
434+
}
435+
421436
// FIXME: Remove src2_modifiers. It isn't used, so is wasting memory
422437
// and processing time but it makes it easier to convert to mad.
423438
class VOP_MAC <ValueType vt0, ValueType vt1=vt0> : VOPProfile <[vt0, vt1, vt1, vt0]> {
424439
let Ins32 = (ins Src0RC32:$src0, Src1RC32:$src1, getVregSrcForVT<Src2VT>.ret:$src2);
425440
// Src2 must accept the same operand types as vdst, namely VGPRs only
426-
let Src2RC64 = getVOP3VRegSrcForVT<Src2VT, IsTrue16, !not(IsRealTrue16)>.ret;
441+
let Src2RC64 = getVOP3VRegForVT<Src2VT, IsTrue16, !not(IsRealTrue16)>.ret;
427442
let Ins64 = getIns64<Src0RC64, Src1RC64, Src2RC64, 3,
428443
0, HasModifiers, HasModifiers, HasOMod,
429444
Src0Mod, Src1Mod, Src2Mod>.ret;

llvm/lib/Target/AMDGPU/VOP3Instructions.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,11 @@ class VOP3_CVT_SCALE_SR_PK_F4_F16BF16_TiedInput_Profile<ValueType Src0Ty> :
10521052
let HasFP4DstByteSel = 1;
10531053
}
10541054

1055-
def VOP3_CVT_SCALE_SR_PK_F4_F32_TiedInput_Profile : VOP3_Profile<VOPProfile<[i32, v2f32, i32, f32]>, VOP3_OPSEL> {
1055+
class VOP3_CVT_SCALE_SR_PK_F4_F32_TiedInput_Profile<VOPProfile P>
1056+
: VOP3_Profile<P, VOP3_OPSEL> {
1057+
1058+
let Src0RC64 = !if(!gt(P.Src0VT.Size, 32), getVOP3VRegSrcForVT<P.Src0VT>.ret,
1059+
getVOP3SrcForVT<P.Src0VT>.ret);
10561060
let InsVOP3OpSel = (ins PackedF32InputMods: $src0_modifiers, Src0RC64:$src0,
10571061
Int32InputMods: $src1_modifiers, Src1RC64:$src1,
10581062
FP32InputMods: $src2_modifiers, Src2RC64:$src2,
@@ -1100,6 +1104,11 @@ class VOP3_CVT_SCALEF32_PK_F864_Profile<VOPProfile P> : VOP3_Profile<P> {
11001104
let HasExt32BitDPP = 0;
11011105
let HasExtVOP3DPP = 0;
11021106
let HasExt64BitDPP = 0;
1107+
1108+
// All convert opcodes operating on FP6/BF6/FP4 data must use VGPR sources for
1109+
// any operand slots > 32 bit.
1110+
let Src0RC64 = !if(!gt(P.Src0VT.Size, 32), getVOP3VRegSrcForVT<P.Src0VT>.ret,
1111+
getVOP3SrcForVT<P.Src0VT>.ret);
11031112
}
11041113

11051114
let SubtargetPredicate = HasFP8ConversionScaleInsts, mayRaiseFPException = 0 in {
@@ -1141,7 +1150,10 @@ let SubtargetPredicate = HasFP4ConversionScaleInsts, mayRaiseFPException = 0 in
11411150
let Constraints = "@earlyclobber $vdst" in {
11421151
defm V_CVT_SCALEF32_SR_PK_FP4_F16: VOP3Inst<"v_cvt_scalef32_sr_pk_fp4_f16", VOP3_CVT_SCALE_SR_PK_F4_F16BF16_TiedInput_Profile<v2f16>>;
11431152
defm V_CVT_SCALEF32_SR_PK_FP4_BF16: VOP3Inst<"v_cvt_scalef32_sr_pk_fp4_bf16", VOP3_CVT_SCALE_SR_PK_F4_F16BF16_TiedInput_Profile<v2bf16>>;
1144-
defm V_CVT_SCALEF32_SR_PK_FP4_F32: VOP3Inst<"v_cvt_scalef32_sr_pk_fp4_f32", VOP3_CVT_SCALE_SR_PK_F4_F32_TiedInput_Profile>;
1153+
defm V_CVT_SCALEF32_SR_PK_FP4_F32
1154+
: VOP3Inst<"v_cvt_scalef32_sr_pk_fp4_f32",
1155+
VOP3_CVT_SCALE_SR_PK_F4_F32_TiedInput_Profile<
1156+
VOP_I32_V2F32_I32_F32>>;
11451157
}
11461158
}
11471159
defm V_CVT_SCALEF32_PK_F16_FP4 : VOP3Inst<"v_cvt_scalef32_pk_f16_fp4", VOP3_CVT_SCALE_PK_F16BF16F32_FP4FP8BF8_Profile<v2f16>>;

0 commit comments

Comments
 (0)