Skip to content

Commit 4564ac9

Browse files
yiqian1, kuhar, jungpark-mlir
authored
Add gfx950 mfma instructions to ROCDL dialect (#123361)
Add ROCDL support for the following instructions: V_MFMA_F32_16X16X32_BF16, V_MFMA_I32_16X16X64_I8, V_MFMA_F32_16X16X32_F16, V_MFMA_F32_32X32X16_BF16, V_MFMA_I32_32X32X32_I8, V_MFMA_F32_32X32X16_F16. --------- Co-authored-by: Jakub Kuderski <[email protected]> Co-authored-by: Jungwook Park <[email protected]>
1 parent 6518b12 commit 4564ac9

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,6 @@ def ROCDL_mfma_i32_16x16x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x32.i8">;
379379
def ROCDL_mfma_i32_32x32x16_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x16.i8">;
380380
def ROCDL_mfma_f32_16x16x8_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x8.xf32">;
381381
def ROCDL_mfma_f32_32x32x4_xf32 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x4.xf32">;
382-
// fp8, only on gfx940
383382
def ROCDL_mfma_f32_16x16x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.bf8">;
384383
def ROCDL_mfma_f32_16x16x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf8.fp8">;
385384
def ROCDL_mfma_f32_16x16x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.fp8.bf8">;
@@ -388,6 +387,13 @@ def ROCDL_mfma_f32_32x32x16_bf8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.b
388387
def ROCDL_mfma_f32_32x32x16_bf8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf8.fp8">;
389388
def ROCDL_mfma_f32_32x32x16_fp8_bf8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.bf8">;
390389
def ROCDL_mfma_f32_32x32x16_fp8_fp8 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.fp8.fp8">;
390+
// New in gfx950.
391+
def ROCDL_mfma_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.bf16">;
392+
def ROCDL_mfma_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.16x16x64.i8">;
393+
def ROCDL_mfma_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.16x16x32.f16">;
394+
def ROCDL_mfma_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.bf16">;
395+
def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
396+
def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
391397

392398
//===---------------------------------------------------------------------===//
393399
// WMMA intrinsics

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,9 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
219219
%arg4 : vector<16 x f32>, %arg5 : vector<4xf32>,
220220
%arg6 : vector<4xf16>, %arg7 : vector<32 x i32>,
221221
%arg8 : vector<16 x i32>, %arg9 : vector<4xi32>,
222-
%arg10 : vector<2xi16>, %arg11 : i64) -> vector<32 x f32> {
222+
%arg10 : vector<2xi16>, %arg11 : i64,
223+
%arg12 : vector<8xbf16>, %arg13 : vector<4xi32>,
224+
%arg14 : vector<8xf16>) -> vector<32 x f32> {
223225
%csti32 = llvm.mlir.constant(42 : i32) : i32
224226

225227
// CHECK-LABEL: rocdl.xdlops
@@ -362,6 +364,37 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
362364
%r27 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11, %arg11, %arg4, %csti32, %csti32, %csti32 :
363365
(i64, i64, vector<16xf32>,
364366
i32, i32, i32) -> vector<16xf32>
367+
368+
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
369+
%r28 = rocdl.mfma.f32.16x16x32.bf16 %arg12, %arg12, %arg5, %csti32, %csti32, %csti32 :
370+
(vector<8xbf16>, vector<8xbf16>, vector<4xf32>,
371+
i32, i32, i32) -> vector<4xf32>
372+
373+
// CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x64.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
374+
%r29 = rocdl.mfma.i32.16x16x64.i8 %arg9, %arg9, %arg9, %csti32, %csti32, %csti32 :
375+
(vector<4xi32>, vector<4xi32>, vector<4xi32>,
376+
i32, i32, i32) -> vector<4xi32>
377+
378+
// CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
379+
%r30 = rocdl.mfma.f32.16x16x32.f16 %arg14, %arg14, %arg5, %csti32, %csti32, %csti32 :
380+
(vector<8xf16>, vector<8xf16>, vector<4xf32>,
381+
i32, i32, i32) -> vector<4xf32>
382+
383+
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
384+
%r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12, %arg12, %arg4, %csti32, %csti32, %csti32 :
385+
(vector<8xbf16>, vector<8xbf16>, vector<16xf32>,
386+
i32, i32, i32) -> vector<16xf32>
387+
388+
// CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x32.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
389+
%r32 = rocdl.mfma.i32.32x32x32.i8 %arg9, %arg9, %arg8, %csti32, %csti32, %csti32 :
390+
(vector<4xi32>, vector<4xi32>, vector<16xi32>,
391+
i32, i32, i32) -> vector<16xi32>
392+
393+
// CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
394+
%r33 = rocdl.mfma.f32.32x32x16.f16 %arg14, %arg14, %arg4, %csti32, %csti32, %csti32 :
395+
(vector<8xf16>, vector<8xf16>, vector<16xf32>,
396+
i32, i32, i32) -> vector<16xf32>
397+
365398
llvm.return %r0 : vector<32 x f32>
366399
}
367400

0 commit comments

Comments
 (0)