Skip to content

Commit 63f7ba9

Browse files
committed
AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions
These use a new VOP3PX encoding for the v_mfma_scale_* instructions, which bundles the pre-scale v_mfma_ld_scale_b32. None of the modifiers are supported yet (op_sel, neg or clamp). I'm not sure the intrinsic should really expose op_sel (or any of the others). If I'm reading the documentation correctly, we should be able to just have the raw scale operands and auto-match op_sel to byte extract patterns. The op_sel syntax also seems extra horrible in this usage, especially with the usual assumed op_sel_hi=-1 behavior.
1 parent a5454b4 commit 63f7ba9

23 files changed

+4739
-22
lines changed

llvm/docs/AMDGPUUsage.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,6 +1397,19 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
13971397
used by hardware to control active lanes when used in EXEC register.
13981398
For example, ballot(i1 true) return EXEC mask.
13991399

1400+
llvm.amdgcn.mfma.f32.16x16x128.f8f6f4.scaled Emit `v_mfma_f32_16x16x128_f8f6f4`, bundled with a `v_mfma_ld_scale_b32`
1401+
to set the scale factor. The last 4 operands correspond to the inputs
1402+
to `v_mfma_ld_scale_b32`:
1403+
1404+
llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
1405+
last 4 operands correspond to the scale inputs.
1406+
2-bit byte index to use for each lane for matrix A
1407+
Matrix A scale values
1408+
2-bit byte index to use for each lane for matrix B
1409+
Matrix B scale values
1410+
1411+
llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`
1412+
14001413
============================================== ==========================================================
14011414

14021415
.. TODO::

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2968,6 +2968,27 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29682968
[IntrConvergent, IntrNoMem,
29692969
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
29702970

2971+
2972+
// Intrinsic signature shared by the scaled MFMA operations.
//
// Arguments, in order: the two source matrices (SrcABTy), the accumulator
// (DestTy), then seven i32 values: cbsz, abid, blgp, and an
// (op_sel, scale value) pair for matrix A followed by one for matrix B.
// The result has type DestTy.
//
// cbsz, abid, blgp and both op_sel arguments (indices 3, 4, 5, 6 and 8) are
// marked ImmArg; the two scale values (indices 7 and 9) are not.
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
2973+
ClangBuiltin<!subst("int", "__builtin", NAME)>,
2974+
DefaultAttrsIntrinsic<[DestTy],
2975+
[SrcABTy, SrcABTy, DestTy,
2976+
llvm_i32_ty, // cbsz
2977+
llvm_i32_ty, // abid
2978+
llvm_i32_ty, // blgp
2979+
// llvm_i1_ty, // TODO: neg_src2
2980+
// llvm_i1_ty, // TODO: abs_src2
2981+
// llvm_i1_ty, // TODO: clamp
2982+
llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
2983+
llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
2984+
llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
2985+
llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
2986+
],
2987+
[IntrConvergent, IntrNoMem,
2988+
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>,
2989+
ImmArg<ArgIndex<6>>, ImmArg<ArgIndex<8>>
2990+
]>;
2991+
29712992
defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
29722993
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
29732994
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
@@ -3119,6 +3140,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
31193140
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
31203141

31213142
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
3143+
// Scaled f8f6f4 MFMA intrinsics. Both take v8i32 A/B sources; the 16x16x128
// form uses a v4f32 accumulator/result and the 32x32x64 form a v16f32 one.
def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty, llvm_v8i32_ty>;
3144+
def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty, llvm_v8i32_ty>;
31223145
}
31233146

31243147
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/AMDGPUGISel.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,
423423

424424
def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
425425
GISDNodeXFormEquiv<as_hw_round_mode>;
426+
427+
// GlobalISel equivalent of the MFMALdScaleXForm SDNodeXForm: operands that go
// through that transform are emitted by the renderScaledMAIIntrinsicOperand
// custom renderer.
def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
428+
GISDNodeXFormEquiv<MFMALdScaleXForm>;

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5737,6 +5737,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
57375737
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
57385738
}
57395739

5740+
/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
5741+
void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
5742+
MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
5743+
unsigned Val = MI.getOperand(OpIdx).getImm();
5744+
unsigned New = 0;
5745+
if (Val & 0x1)
5746+
New |= SISrcMods::OP_SEL_0;
5747+
if (Val & 0x2)
5748+
New |= SISrcMods::OP_SEL_1;
5749+
MIB.addImm(New);
5750+
}
5751+
57405752
bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
57415753
return TII.isInlineConstant(Imm);
57425754
}

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
364364

365365
void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI,
366366
int OpIdx) const;
367+
void renderScaledMAIIntrinsicOperand(MachineInstrBuilder &MIB,
368+
const MachineInstr &MI, int OpIdx) const;
367369

368370
bool isInlineImmediate(const APInt &Imm) const;
369371
bool isInlineImmediate(const APFloat &Imm) const;

llvm/lib/Target/AMDGPU/AMDGPUInstructions.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class AMDGPUInst <dag outs, dag ins, string asm = "",
4040
// instructions to not match without killing the whole decode process. It is
4141
// mainly used for ARM, but Tablegen expects this field to exist or it fails
4242
// to build the decode table.
43-
field bits<96> SoftFail = 0;
43+
field bits<128> SoftFail = 0; // FIXME: If this is smaller than largest instruction, DecodeEmitter crashes
4444

4545
let DecoderNamespace = Namespace;
4646

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4769,6 +4769,25 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
47694769
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
47704770
break;
47714771
}
4772+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
4773+
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
4774+
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
4775+
OpdsMapping[0] =
4776+
Info->mayNeedAGPRs()
4777+
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
4778+
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
4779+
4780+
OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
4781+
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
4782+
OpdsMapping[4] =
4783+
Info->mayNeedAGPRs()
4784+
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
4785+
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
4786+
4787+
OpdsMapping[9] = getVGPROpMapping(MI.getOperand(9).getReg(), MRI, *TRI);
4788+
OpdsMapping[11] = getVGPROpMapping(MI.getOperand(11).getReg(), MRI, *TRI);
4789+
break;
4790+
}
47724791
case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
47734792
case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
47744793
case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:

llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,17 @@ static inline DecoderUInt128 eat12Bytes(ArrayRef<uint8_t> &Bytes) {
493493
return DecoderUInt128(Lo, Hi);
494494
}
495495

496+
/// Consume the first 16 bytes of \p Bytes as two little-endian 64-bit words
/// and return them packed in a DecoderUInt128 (first word passed as the low
/// half, second word as the high half). \p Bytes is advanced past the
/// consumed data. The caller must guarantee at least 16 bytes are available.
static inline DecoderUInt128 eat16Bytes(ArrayRef<uint8_t> &Bytes) {
  assert(Bytes.size() >= 16);
  const uint8_t *const Data = Bytes.data();
  const uint64_t LowWord =
      support::endian::read<uint64_t, llvm::endianness::little>(Data);
  const uint64_t HighWord =
      support::endian::read<uint64_t, llvm::endianness::little>(Data + 8);
  // Drop both words at once; equivalent to slicing 8 bytes after each read.
  Bytes = Bytes.slice(16);
  return DecoderUInt128(LowWord, HighWord);
}
506+
496507
DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
497508
ArrayRef<uint8_t> Bytes_,
498509
uint64_t Address,
@@ -529,6 +540,15 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
529540

530541
// Reinitialize Bytes
531542
Bytes = Bytes_.slice(0, MaxInstBytesNum);
543+
544+
} else if (Bytes.size() >= 16 &&
545+
STI.hasFeature(AMDGPU::FeatureGFX950Insts)) {
546+
DecoderUInt128 DecW = eat16Bytes(Bytes);
547+
if (tryDecodeInst(DecoderTableGFX940128, MI, DecW, Address, CS))
548+
break;
549+
550+
// Reinitialize Bytes
551+
Bytes = Bytes_.slice(0, MaxInstBytesNum);
532552
}
533553

534554
if (Bytes.size() >= 8) {

llvm/lib/Target/AMDGPU/MCTargetDesc/AMDGPUMCAsmInfo.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ unsigned AMDGPUMCAsmInfo::getMaxInstLength(const MCSubtargetInfo *STI) const {
5959
if (STI->hasFeature(AMDGPU::FeatureNSAEncoding))
6060
return 20;
6161

62+
// VOP3PX encoding.
63+
if (STI->hasFeature(AMDGPU::FeatureGFX950Insts))
64+
return 16;
65+
6266
// 64-bit instruction with 32-bit literal.
6367
if (STI->hasFeature(AMDGPU::FeatureVOP3Literal))
6468
return 12;

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15449,6 +15449,23 @@ void SITargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
1544915449
MRI.setRegClass(Op.getReg(), NewRC);
1545015450
}
1545115451

15452+
if (TII->isMAI(MI)) {
  // The ordinary src0, src1, src2 were legalized above.
  //
  // We have to also legalize the appended v_mfma_ld_scale_b32 operands,
  // as a separate instruction.
  int Src0Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
                                           AMDGPU::OpName::scale_src0);
  // Only opcodes that carry the bundled scale operands have scale_src0.
  if (Src0Idx != -1) {
    // scale_src1 is expected to sit two operands after scale_src0. Verify
    // that layout with a comparison; an assignment here would overwrite
    // Src1Idx in asserts builds and could never catch a layout mismatch.
    int Src1Idx = Src0Idx + 2;
    assert(Src1Idx == AMDGPU::getNamedOperandIdx(
                          MI.getOpcode(), AMDGPU::OpName::scale_src1));
    // If both scale operands would use the constant bus, rewrite scale_src1
    // through legalizeOpWithMove (presumably because only one constant-bus
    // read is allowed here -- confirm against the hardware rules).
    if (TII->usesConstantBus(MRI, MI, Src0Idx) &&
        TII->usesConstantBus(MRI, MI, Src1Idx))
      TII->legalizeOpWithMove(MI, Src1Idx);
  }
}
15468+
1545215469
if (!HasAGPRs)
1545315470
return;
1545415471

llvm/lib/Target/AMDGPU/SIInstrFormats.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,11 @@ class Enc96 {
300300
int Size = 12;
301301
}
302302

303+
// Encoding mix-in for instructions with a 128-bit encoding: provides the
// 128-bit Inst field and a Size of 16 (128 bits expressed in bytes).
class Enc128 {
304+
field bits<128> Inst;
305+
int Size = 16;
306+
}
307+
303308
def CPolBit {
304309
int GLC = 0;
305310
int SLC = 1;

llvm/lib/Target/AMDGPU/SIInstrInfo.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,12 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
11151115
const MachineOperand &MO,
11161116
const MCOperandInfo &OpInfo) const;
11171117

1118+
bool usesConstantBus(const MachineRegisterInfo &MRI, const MachineInstr &MI,
1119+
int OpIdx) const {
1120+
return usesConstantBus(MRI, MI.getOperand(OpIdx),
1121+
MI.getDesc().operands()[OpIdx]);
1122+
}
1123+
11181124
/// Return true if this instruction has any modifiers.
11191125
/// e.g. src[012]_mod, omod, clamp.
11201126
bool hasModifiers(unsigned Opcode) const;

llvm/lib/Target/AMDGPU/SIInstrInfo.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,16 @@ def fp16_zeros_high_16bits : PatLeaf<(f16 VGPR_32:$src), [{
914914
return fp16SrcZerosHighBits(N->getOpcode());
915915
}]>;
916916

917+
// Rewrites a 2-bit op_sel target immediate into the SISrcMods encoding:
// bit 0 becomes SISrcMods::OP_SEL_0 and bit 1 becomes SISrcMods::OP_SEL_1.
// The result is an i32 target constant.
def MFMALdScaleXForm : SDNodeXForm<timm, [{
  const unsigned OpSel = N->getZExtValue();
  const unsigned LowBitMod =
      (OpSel & 0x1) ? static_cast<unsigned>(SISrcMods::OP_SEL_0) : 0u;
  const unsigned HighBitMod =
      (OpSel & 0x2) ? static_cast<unsigned>(SISrcMods::OP_SEL_1) : 0u;
  return CurDAG->getTargetConstant(LowBitMod | HighBitMod, SDLoc(N), MVT::i32);
}]>;
926+
917927
def is_canonicalized : PatLeaf<(fAny srcvalue:$src), [{
918928
const SITargetLowering &Lowering =
919929
*static_cast<const SITargetLowering *>(getTargetLowering());
@@ -1515,6 +1525,10 @@ class PackedIntInputMods <PackedIntInputModsMatchClass matchClass> : InputMods <
15151525
def PackedF16InputMods : PackedFPInputMods<PackedF16InputModsMatchClass>;
15161526
def PackedI16InputMods : PackedIntInputMods<PackedI16InputModsMatchClass>;
15171527

1528+
// Matches an i32 target immediate that fits in two unsigned bits
// (isUInt<2>) and rewrites it through MFMALdScaleXForm.
def MFMALdScaleModifierOp : TImmLeaf<i32, [{
1529+
return isUInt<2>(Imm);
1530+
}], MFMALdScaleXForm>;
1531+
15181532
//===----------------------------------------------------------------------===//
15191533
// Complex patterns
15201534
//===----------------------------------------------------------------------===//
@@ -2851,6 +2865,8 @@ def VOP_V16F32_V2I32_V4I32_I32 : VOPProfile <[v16f32, v2i32, v4i32, i32]>;
28512865
def VOP_V4F32_V8F16_V8F16_V4F32 : VOPProfile <[v4f32, v8f16, v8f16, v4f32]>;
28522866
def VOP_V16F32_V8F16_V8F16_V16F32 : VOPProfile <[v16f32, v8f16, v8f16, v16f32]>;
28532867
def VOP_V16F32_V8BF16_V8BF16_V16F32 : VOPProfile <[v16f32, v8bf16, v8bf16, v16f32]>;
2868+
def VOP_V4F32_V8I32_V8I32_V4F32 : VOPProfile <[v4f32, v8i32, v8i32, v4f32]>;
2869+
def VOP_V16F32_V8I32_V8I32_V16F32 : VOPProfile <[v16f32, v8i32, v8i32, v16f32]>;
28542870

28552871

28562872
class Commutable_REV <string revOp, bit isOrig> {

llvm/lib/Target/AMDGPU/SIRegisterInfo.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,11 +1338,13 @@ class AVSrcOperand<RegisterClass regClass, string width>
13381338
def AVSrc_32 : AVSrcOperand<AV_32, "OPW32">;
13391339
def AVSrc_64 : AVSrcOperand<AV_64, "OPW64">;
13401340
def AVSrc_128 : AVSrcOperand<AV_128, "OPW128">;
1341+
def AVSrc_256 : AVSrcOperand<AV_256, "OPW256">;
13411342

13421343
class AVDstOperand<RegisterClass regClass, string width>
13431344
: AVOperand<regClass, "decodeAV10", width>;
13441345

13451346
def AVDst_128 : AVDstOperand<AV_128, "OPW128">;
1347+
def AVDst_256 : AVDstOperand<AV_256, "OPW256">;
13461348
def AVDst_512 : AVDstOperand<AV_512, "OPW512">;
13471349

13481350
class AVLdStOperand<RegisterClass regClass, string width>

0 commit comments

Comments
 (0)