Skip to content

Commit 7efa846

Browse files
committed
AMDGPU: Define v_mfma_f32_{16x16x128|32x32x64}_f8f6f4 instructions
These use a new VOP3PX encoding for the v_mfma_scale_* instructions, which bundles the pre-scale v_mfma_ld_scale_b32. None of the modifiers are supported yet (op_sel, neg or clamp). I'm not sure the intrinsic should really expose op_sel (or any of the others). If I'm reading the documentation correctly, we should be able to just have the raw scale operands and auto-match op_sel to byte extract patterns. The op_sel syntax also seems extra horrible in this usage, especially with the usual assumed op_sel_hi=-1 behavior. The f8f6f4 intrinsics allow using different vector types, corresponding to the 3 different format widths. These can use 4, 6 or 8 x i32 vectors depending on if the format is fp4, fp6/bf6, or fp8/bf8. Verification that the used format matches the vector type will come later. This requires defining a separate pseudoinstruction for each register width combination, so 9 pseudos per opcode. This makes disassembly ambiguous, since now the opcode to use depends on the operand. Handle this by only defining the _f8_f8 variant as a real instruction, and the disassembler manually adjusts the opcode based on the format values later. The clang builtin integer operands should probably be unsigned, but all the other mfma intrinsics are using signed.
1 parent 201f4f6 commit 7efa846

33 files changed

+10121
-33
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,11 @@ TARGET_BUILTIN(__builtin_amdgcn_cvt_sr_fp8_f32, "ifiiIi", "nc", "fp8-conversion-
434434
//===----------------------------------------------------------------------===//
435435
// GFX950 only builtins.
436436
//===----------------------------------------------------------------------===//
437+
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4, "V4fV8ZiV8ZiV4fIiIiIiiIii", "nc", "gfx950-insts")
438+
TARGET_BUILTIN(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4, "V16fV8ZiV8ZiV16fIiIiIiiIii", "nc", "gfx950-insts")
439+
437440
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_16x16x32_f16, "V4fV8hV8hV4fIiIiIi", "nc", "gfx950-insts")
438441
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_f16, "V16fV8hV8hV16fIiIiIi", "nc", "gfx950-insts")
439-
440442
TARGET_BUILTIN(__builtin_amdgcn_mfma_f32_32x32x16_bf16, "V16fV8yV8yV16fIiIiIi", "nc", "gfx950-insts")
441443

442444
//===----------------------------------------------------------------------===//

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19729,7 +19729,20 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1972919729
(uint64_t)0);
1973019730
return Builder.CreateInsertElement(I0, A, 1);
1973119731
}
19732-
19732+
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
19733+
case AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
19734+
llvm::FixedVectorType *VT = FixedVectorType::get(Builder.getInt32Ty(), 8);
19735+
Function *F = CGM.getIntrinsic(
19736+
BuiltinID == AMDGPU::BI__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4
19737+
? Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4
19738+
: Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4,
19739+
{VT, VT});
19740+
19741+
SmallVector<Value *, 9> Args;
19742+
for (unsigned I = 0, N = E->getNumArgs(); I != N; ++I)
19743+
Args.push_back(EmitScalarExpr(E->getArg(I)));
19744+
return Builder.CreateCall(F, Args);
19745+
}
1973319746
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1973419747
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1973519748
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:

clang/test/CodeGenOpenCL/builtins-amdgcn-mfma.cl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ typedef half v16h __attribute__((ext_vector_type(16)));
1616
typedef half v32h __attribute__((ext_vector_type(32)));
1717
typedef int v2i __attribute__((ext_vector_type(2)));
1818
typedef int v4i __attribute__((ext_vector_type(4)));
19+
typedef int v8i __attribute__((ext_vector_type(8)));
1920
typedef int v16i __attribute__((ext_vector_type(16)));
2021
typedef int v32i __attribute__((ext_vector_type(32)));
2122
typedef short v2s __attribute__((ext_vector_type(2)));
@@ -431,4 +432,18 @@ v16f test_mfma_f32_32x32x16_bf16(v8bf16 a, v8bf16 b, v16f c) {
431432
return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 1, 2, 3);
432433
}
433434

435+
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_16x16x128_f8f6f4
436+
// CHECK-GFX950: call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <4 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
437+
void test_mfma_scale_f32_16x16x128_f8f6f4(global v4f* out, v8i a, v8i b, v4f c, int scale_a, int scale_b)
438+
{
439+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
440+
}
441+
442+
// CHECK-GFX950-LABEL: @test_mfma_scale_f32_32x32x64_f8f6f4
443+
// CHECK-GFX950: call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %a, <8 x i32> %b, <16 x float> %c, i32 3, i32 1, i32 2, i32 %scale_a, i32 3, i32 %scale_b)
444+
void test_mfma_scale_f32_32x32x64_f8f6f4(global v16f* out, v8i a, v8i b, v16f c, int scale_a, int scale_b)
445+
{
446+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 3, 1, 2, scale_a, 3, scale_b);
447+
}
448+
434449
#endif

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950-param.cl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
77
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
8+
typedef int int8 __attribute__((ext_vector_type(8)));
89

910

1011
void test_mfma_f32_16x16x32_f16(__global float4* out, half8 a, half8 b, float4 c, int X) {
@@ -26,3 +27,17 @@ void test_mfma_f32_32x32x16_bf16(__global float16* out, bfloat8 a, bfloat8 b, fl
2627
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, X, 0); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
2728
*out = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a, b, c, 0, 0, X); // expected-error{{argument to '__builtin_amdgcn_mfma_f32_32x32x16_bf16' must be a constant integer}}
2829
}
30+
31+
void test_mfma_scale_f32_16x16x128_f8f6f4(__global float4* out, int8 a, int8 b, float4 c, int X, int Y) {
32+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
33+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
34+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
35+
*out = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' must be a constant integer}}
36+
}
37+
38+
void test_mfma_scale_f32_32x32x64_f8f6f4(__global float16* out, int8 a, int8 b, float16 c, int X, int Y) {
39+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, X, 0, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
40+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, X, 1, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
41+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, X, Y, 2, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
42+
*out = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, 0, 0, 0, Y, X, Y); // expected-error{{argument to '__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' must be a constant integer}}
43+
}

clang/test/SemaOpenCL/builtins-amdgcn-error-gfx950.cl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,33 @@
44
typedef float float4 __attribute__((ext_vector_type(4)));
55
typedef float float16 __attribute__((ext_vector_type(16)));
66
typedef half half8 __attribute__((ext_vector_type(8)));
7+
typedef half half16 __attribute__((ext_vector_type(16)));
78
typedef __bf16 bfloat8 __attribute__((ext_vector_type(8)));
9+
typedef __bf16 bfloat16 __attribute__((ext_vector_type(16)));
10+
typedef unsigned int uint2 __attribute__((ext_vector_type(2)));
11+
typedef int int4 __attribute__((ext_vector_type(4)));
12+
typedef int int8 __attribute__((ext_vector_type(8)));
13+
typedef int int16 __attribute__((ext_vector_type(16)));
814

915
void test(__global float4* out0, half8 a0, half8 b0, float4 c0,
1016
__global float16* out1, half8 a1, half8 b1, float16 c1,
11-
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2) {
17+
__global float16* out2, bfloat8 a2, bfloat8 b2, float16 c2,
18+
__global int4* out3, int4 a3, int4 b3, int4 c3,
19+
__global int16* out4, int4 a4, int4 b4, int16 c4,
20+
__global float4* out5, bfloat8 a5, bfloat8 b5, float4 c5,
21+
__global float4* out6, half8 a6, half16 b6, float4 c6,
22+
__global float16* out7, half8 a7, half16 b7, float16 c7,
23+
__global float4* out8, bfloat8 a8, bfloat16 b8, float4 c8,
24+
__global float16* out9, bfloat8 a9, bfloat16 b9, float16 c9,
25+
__global int4* out10, int4 a10, int8 b10, int4 c10,
26+
__global int16* out11, int4 a11, int8 b11, int16 c11,
27+
__global float4* out12, int4 a12, int8 b12, float4 c12,
28+
__global float16* out13, int4 a13, int8 b13, float16 c13,
29+
__global float4* out14, int8 a14, int8 b14, float4 c14, int d14, int e14,
30+
__global float16* out15, int8 a15, int8 b15, float16 c15, int d15, int e15) {
1231
*out0 = __builtin_amdgcn_mfma_f32_16x16x32_f16(a0, b0, c0, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_16x16x32_f16' needs target feature gfx950-insts}}
1332
*out1 = __builtin_amdgcn_mfma_f32_32x32x16_f16(a1, b1, c1, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_f16' needs target feature gfx950-insts}}
1433
*out2 = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a2, b2, c2, 0, 0, 0); // expected-error{{'__builtin_amdgcn_mfma_f32_32x32x16_bf16' needs target feature gfx950-insts}}
34+
*out14 = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a14, b14, c14, 0, 0, 0, d14, 0, e14); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4' needs target feature gfx950-insts}}
35+
*out15 = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a15, b15, c15, 0, 0, 0, d15, 0, e15); // expected-error{{'__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4' needs target feature gfx950-insts}}
1536
}

llvm/docs/AMDGPUUsage.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1397,6 +1397,16 @@ The AMDGPU backend implements the following LLVM IR intrinsics.
13971397
used by hardware to control active lanes when used in EXEC register.
13981398
For example, ballot(i1 true) return EXEC mask.
13991399

1400+
llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4 Emit `v_mfma_scale_f32_16x16x128_f8f6f4` to set the scale factor. The
1401+
last 4 operands correspond to the scale inputs.
1402+
1403+
- 2-bit byte index to use for each lane for matrix A
1404+
- Matrix A scale values
1405+
- 2-bit byte index to use for each lane for matrix B
1406+
- Matrix B scale values
1407+
1408+
llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4 Emit `v_mfma_scale_f32_32x32x64_f8f6f4`
1409+
14001410
============================================== ==========================================================
14011411

14021412
.. TODO::

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2968,6 +2968,35 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29682968
[IntrConvergent, IntrNoMem,
29692969
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
29702970

2971+
2972+
// srcA's format is determined by cbsz. srcB's format is determined by
2973+
// blgp.
2974+
//
2975+
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
2976+
// and <4 x i32> for f4 formats. If the format control bits imply a
2977+
// smaller type than used, the high elements will be truncated.
2978+
//
2979+
// If the format control bits imply a larger type than used, the high
2980+
// elements are padded with undef.
2981+
2982+
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
2983+
DefaultAttrsIntrinsic<[DestTy],
2984+
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,
2985+
llvm_i32_ty, // cbsz
2986+
llvm_i32_ty, // blgp
2987+
// llvm_i1_ty, // TODO: neg_src2
2988+
// llvm_i1_ty, // TODO: abs_src2
2989+
// llvm_i1_ty, // TODO: clamp
2990+
llvm_i32_ty, // op_sel (A matrix scale, 2-bits) // TODO: Make i2?
2991+
llvm_i32_ty, // v_mfma_ld_scale_b32 src0 (A matrix scale)
2992+
llvm_i32_ty, // op_sel (B matrix scale, 2-bits) // TODO: Make i2?
2993+
llvm_i32_ty // v_mfma_ld_scale_b32 src1 (B matrix scale)
2994+
],
2995+
[IntrConvergent, IntrNoMem,
2996+
ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>,
2997+
ImmArg<ArgIndex<5>>, ImmArg<ArgIndex<7>>
2998+
]>;
2999+
29713000
defset list<Intrinsic> AMDGPUMFMAIntrinsics908 = {
29723001
def int_amdgcn_mfma_f32_32x32x1f32 : AMDGPUMfmaIntrinsic<llvm_v32f32_ty, llvm_float_ty>;
29733002
def int_amdgcn_mfma_f32_16x16x1f32 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_float_ty>;
@@ -3119,6 +3148,8 @@ def int_amdgcn_mfma_f32_16x16x32_f16 : AMDGPUMfmaIntrinsic<llvm_v4f32_ty, llvm_v
31193148
def int_amdgcn_mfma_f32_32x32x16_f16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8f16_ty>;
31203149

31213150
def int_amdgcn_mfma_f32_32x32x16_bf16 : AMDGPUMfmaIntrinsic<llvm_v16f32_ty, llvm_v8bf16_ty>;
3151+
def int_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v4f32_ty>;
3152+
def int_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 : AMDGPUMfmaScaleIntrinsic<llvm_v16f32_ty>;
31223153
}
31233154

31243155
//===----------------------------------------------------------------------===//

llvm/lib/Target/AMDGPU/AMDGPUGISel.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,3 +423,6 @@ def gi_fp_pow2_to_exponent : GICustomOperandRenderer<"renderFPPow2ToExponent">,
423423

424424
def gi_as_hw_round_mode : GICustomOperandRenderer<"renderRoundMode">,
425425
GISDNodeXFormEquiv<as_hw_round_mode>;
426+
427+
def gi_MFMALdScaleModifierOp : GICustomOperandRenderer<"renderScaledMAIIntrinsicOperand">,
428+
GISDNodeXFormEquiv<MFMALdScaleXForm>;

llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,7 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
12581258
if (isa<UndefValue>(Src)) {
12591259
return IC.replaceInstUsesWith(II, Src);
12601260
}
1261+
return std::nullopt;
12611262
}
12621263
}
12631264
if (const AMDGPU::ImageDimIntrinsicInfo *ImageDimIntr =

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5737,6 +5737,18 @@ void AMDGPUInstructionSelector::renderRoundMode(MachineInstrBuilder &MIB,
57375737
MIB.addImm((MI.getOperand(OpIdx).getImm() + 3) % 4);
57385738
}
57395739

5740+
/// Convert from 2-bit value to enum values used for op_sel* source modifiers.
5741+
void AMDGPUInstructionSelector::renderScaledMAIIntrinsicOperand(
5742+
MachineInstrBuilder &MIB, const MachineInstr &MI, int OpIdx) const {
5743+
unsigned Val = MI.getOperand(OpIdx).getImm();
5744+
unsigned New = 0;
5745+
if (Val & 0x1)
5746+
New |= SISrcMods::OP_SEL_0;
5747+
if (Val & 0x2)
5748+
New |= SISrcMods::OP_SEL_1;
5749+
MIB.addImm(New);
5750+
}
5751+
57405752
bool AMDGPUInstructionSelector::isInlineImmediate(const APInt &Imm) const {
57415753
return TII.isInlineConstant(Imm);
57425754
}

llvm/lib/Target/AMDGPU/AMDGPUInstructionSelector.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,8 @@ class AMDGPUInstructionSelector final : public InstructionSelector {
364364

365365
void renderRoundMode(MachineInstrBuilder &MIB, const MachineInstr &MI,
366366
int OpIdx) const;
367+
void renderScaledMAIIntrinsicOperand(MachineInstrBuilder &MIB,
368+
const MachineInstr &MI, int OpIdx) const;
367369

368370
bool isInlineImmediate(const APInt &Imm) const;
369371
bool isInlineImmediate(const APFloat &Imm) const;

llvm/lib/Target/AMDGPU/AMDGPUInstructions.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class AMDGPUInst <dag outs, dag ins, string asm = "",
4040
// instructions to not match without killing the whole decode process. It is
4141
// mainly used for ARM, but Tablegen expects this field to exist or it fails
4242
// to build the decode table.
43-
field bits<96> SoftFail = 0;
43+
field bits<128> SoftFail = 0; // FIXME: If this is smaller than largest instruction, DecodeEmitter crashes
4444

4545
let DecoderNamespace = Namespace;
4646

llvm/lib/Target/AMDGPU/AMDGPURegisterBankInfo.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4769,6 +4769,25 @@ AMDGPURegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
47694769
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
47704770
break;
47714771
}
4772+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
4773+
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
4774+
const SIMachineFunctionInfo *Info = MF.getInfo<SIMachineFunctionInfo>();
4775+
OpdsMapping[0] =
4776+
Info->mayNeedAGPRs()
4777+
? getAGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI)
4778+
: getVGPROpMapping(MI.getOperand(0).getReg(), MRI, *TRI);
4779+
4780+
OpdsMapping[2] = getVGPROpMapping(MI.getOperand(2).getReg(), MRI, *TRI);
4781+
OpdsMapping[3] = getVGPROpMapping(MI.getOperand(3).getReg(), MRI, *TRI);
4782+
OpdsMapping[4] =
4783+
Info->mayNeedAGPRs()
4784+
? getAGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI)
4785+
: getVGPROpMapping(MI.getOperand(4).getReg(), MRI, *TRI);
4786+
4787+
OpdsMapping[8] = getVGPROpMapping(MI.getOperand(8).getReg(), MRI, *TRI);
4788+
OpdsMapping[10] = getVGPROpMapping(MI.getOperand(10).getReg(), MRI, *TRI);
4789+
break;
4790+
}
47724791
case Intrinsic::amdgcn_smfmac_f32_16x16x32_f16:
47734792
case Intrinsic::amdgcn_smfmac_f32_32x32x16_f16:
47744793
case Intrinsic::amdgcn_smfmac_f32_16x16x32_bf16:

llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,17 @@ static inline DecoderUInt128 eat12Bytes(ArrayRef<uint8_t> &Bytes) {
512512
return DecoderUInt128(Lo, Hi);
513513
}
514514

515+
static inline DecoderUInt128 eat16Bytes(ArrayRef<uint8_t> &Bytes) {
516+
assert(Bytes.size() >= 16);
517+
uint64_t Lo =
518+
support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
519+
Bytes = Bytes.slice(8);
520+
uint64_t Hi =
521+
support::endian::read<uint64_t, llvm::endianness::little>(Bytes.data());
522+
Bytes = Bytes.slice(8);
523+
return DecoderUInt128(Lo, Hi);
524+
}
525+
515526
DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
516527
ArrayRef<uint8_t> Bytes_,
517528
uint64_t Address,
@@ -548,6 +559,15 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
548559

549560
// Reinitialize Bytes
550561
Bytes = Bytes_.slice(0, MaxInstBytesNum);
562+
563+
} else if (Bytes.size() >= 16 &&
564+
STI.hasFeature(AMDGPU::FeatureGFX950Insts)) {
565+
DecoderUInt128 DecW = eat16Bytes(Bytes);
566+
if (tryDecodeInst(DecoderTableGFX940128, MI, DecW, Address, CS))
567+
break;
568+
569+
// Reinitialize Bytes
570+
Bytes = Bytes_.slice(0, MaxInstBytesNum);
551571
}
552572

553573
if (Bytes.size() >= 8) {
@@ -759,6 +779,9 @@ DecodeStatus AMDGPUDisassembler::getInstruction(MCInst &MI, uint64_t &Size,
759779
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::SDWA)
760780
convertSDWAInst(MI);
761781

782+
if (MCII->get(MI.getOpcode()).TSFlags & SIInstrFlags::IsMAI)
783+
convertMAIInst(MI);
784+
762785
int VDstIn_Idx = AMDGPU::getNamedOperandIdx(MI.getOpcode(),
763786
AMDGPU::OpName::vdst_in);
764787
if (VDstIn_Idx != -1) {
@@ -837,6 +860,58 @@ void AMDGPUDisassembler::convertSDWAInst(MCInst &MI) const {
837860
}
838861
}
839862

863+
/// Adjust the register values used by V_MFMA_F8F6F4_f8_f8 instructions to the
864+
/// appropriate subregister for the used format width.
865+
static void adjustMFMA_F8F6F4OpRegClass(const MCRegisterInfo &MRI,
866+
MCOperand &MO, uint8_t NumRegs) {
867+
switch (NumRegs) {
868+
case 4:
869+
return MO.setReg(MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3));
870+
case 6:
871+
return MO.setReg(
872+
MRI.getSubReg(MO.getReg(), AMDGPU::sub0_sub1_sub2_sub3_sub4_sub5));
873+
case 8:
874+
// No-op in cases where one operand is still f8/bf8.
875+
return;
876+
default:
877+
llvm_unreachable("Unexpected size for mfma f8f6f4 operand");
878+
}
879+
}
880+
881+
/// f8f6f4 instructions have different pseudos depending on the used formats. In
882+
/// the disassembler table, we only have the variants with the largest register
883+
/// classes which assume using an fp8/bf8 format for both operands. The actual
884+
/// register class depends on the format in blgp and cbsz operands. Adjust the
885+
/// register classes depending on the used format.
886+
void AMDGPUDisassembler::convertMAIInst(MCInst &MI) const {
887+
int BlgpIdx =
888+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::blgp);
889+
if (BlgpIdx == -1)
890+
return;
891+
892+
int CbszIdx =
893+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::cbsz);
894+
895+
unsigned CBSZ = MI.getOperand(CbszIdx).getImm();
896+
unsigned BLGP = MI.getOperand(BlgpIdx).getImm();
897+
898+
const AMDGPU::MFMA_F8F6F4_Info *AdjustedRegClassOpcode =
899+
AMDGPU::getMFMA_F8F6F4_WithFormatArgs(CBSZ, BLGP, MI.getOpcode());
900+
if (!AdjustedRegClassOpcode ||
901+
AdjustedRegClassOpcode->Opcode == MI.getOpcode())
902+
return;
903+
904+
MI.setOpcode(AdjustedRegClassOpcode->Opcode);
905+
int Src0Idx =
906+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src0);
907+
int Src1Idx =
908+
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::src1);
909+
adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src0Idx),
910+
AdjustedRegClassOpcode->NumRegsSrcA);
911+
adjustMFMA_F8F6F4OpRegClass(MRI, MI.getOperand(Src1Idx),
912+
AdjustedRegClassOpcode->NumRegsSrcB);
913+
}
914+
840915
struct VOPModifiers {
841916
unsigned OpSel = 0;
842917
unsigned OpSelHi = 0;

llvm/lib/Target/AMDGPU/Disassembler/AMDGPUDisassembler.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class AMDGPUDisassembler : public MCDisassembler {
204204
void convertVINTERPInst(MCInst &MI) const;
205205
void convertFMAanyK(MCInst &MI, int ImmLitIdx) const;
206206
void convertSDWAInst(MCInst &MI) const;
207+
void convertMAIInst(MCInst &MI) const;
207208
void convertDPP8Inst(MCInst &MI) const;
208209
void convertMIMGInst(MCInst &MI) const;
209210
void convertVOP3DPPInst(MCInst &MI) const;

0 commit comments

Comments
 (0)