Skip to content

Commit 43a50de

Browse files
authored
[MLIR][ROCDL] Add GFX940 SMFMAC (2:4 sparsity) instructions to the ROCDL dialect (#124435)
Overview: This PR adds 2:4 structured-sparsity (sparse A, dense B) matrix multiply instructions (SMFMAC) to the ROCDL dialect. Testing: tests were added to mlir/test/Dialect/LLVMIR/rocdl.mlir and mlir/test/Target/LLVMIR/rocdl.mlir.
1 parent 14ffff3 commit 43a50de

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,24 @@ def ROCDL_mfma_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"mfma.i32.32x32x32.i8">;
408408
def ROCDL_mfma_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"mfma.f32.32x32x16.f16">;
409409
def ROCDL_mfma_scale_f32_16x16x128_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.16x16x128.f8f6f4", [0,1]>;
410410
def ROCDL_mfma_scale_f32_32x32x64_f8f6f4 : ROCDL_Mfma_OO_IntrOp<"mfma.scale.f32.32x32x64.f8f6f4", [0,1]>;
411+
412+
// 2:4 Sparsity ops (GFX940)
// SMFMAC: matrix multiply with 2:4 structured sparsity (sparse A, dense B).
// Every op below reuses the ROCDL_Mfma_IntrOp class and differs only in its
// "smfmac.<result type>.<MxNxK>.<input type(s)>" mnemonic.
413+
def ROCDL_smfmac_f32_16x16x32_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.f16">;
414+
def ROCDL_smfmac_f32_32x32x16_f16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.f16">;
415+
def ROCDL_smfmac_f32_16x16x32_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x32.bf16">;
416+
def ROCDL_smfmac_f32_32x32x16_bf16 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x16.bf16">;
417+
def ROCDL_smfmac_i32_16x16x64_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.16x16x64.i8">;
418+
def ROCDL_smfmac_i32_32x32x32_i8 : ROCDL_Mfma_IntrOp<"smfmac.i32.32x32x32.i8">;
419+
def ROCDL_smfmac_f32_16x16x64_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.bf8">;
420+
def ROCDL_smfmac_f32_16x16x64_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.bf8.fp8">;
421+
def ROCDL_smfmac_f32_16x16x64_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.bf8">;
422+
def ROCDL_smfmac_f32_16x16x64_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.16x16x64.fp8.fp8">;
423+
def ROCDL_smfmac_f32_32x32x32_bf8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.bf8">;
424+
def ROCDL_smfmac_f32_32x32x32_bf8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.bf8.fp8">;
425+
def ROCDL_smfmac_f32_32x32x32_fp8_bf8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.bf8">;
426+
def ROCDL_smfmac_f32_32x32x32_fp8_fp8 : ROCDL_Mfma_IntrOp<"smfmac.f32.32x32x32.fp8.fp8">;
427+
428+
411429
//===---------------------------------------------------------------------===//
412430
// WMMA intrinsics
413431
class ROCDL_Wmma_IntrOp<string mnemonic, list<int> overloadedOperands,

mlir/test/Dialect/LLVMIR/rocdl.mlir

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,93 @@ func.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
258258
llvm.return
259259
}
260260

261+
262+
// Dialect-level coverage for the rocdl.smfmac.* ops: each of the 14 variants
// is instantiated once and its printed op name and type signature are matched.
// The three trailing i32 operands of every op reuse the same constant
// (%csti32 = 42); only %r0 is returned, the other results are unused.
// NOTE(review): %arg0 is declared but never referenced in the body.
llvm.func @rocdl.smfmac(%arg0 : i32,
263+
%arg1 : vector<4 x f16>,
264+
%arg2 : vector<8 x f16>,
265+
%arg3 : vector<4 x f32>,
266+
%arg4 : vector<16 x f32>,
267+
%arg5 : vector<4 x i16>,
268+
%arg6 : vector<8 x i16>,
269+
%arg7 : vector<2xi32>,
270+
%arg8 : vector<4xi32>,
271+
%arg9 : vector<16xi32>) -> vector<4 x f32> {
272+
%csti32 = llvm.mlir.constant(42 : i32) : i32
273+
274+
// CHECK-LABEL: rocdl.smfmac
275+
// CHECK: rocdl.smfmac.f32.16x16x32.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
276+
%r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
277+
(vector<4xf16>, vector<8xf16>, vector<4xf32>,
278+
i32, i32, i32) -> vector<4xf32>
279+
280+
// CHECK: rocdl.smfmac.f32.32x32x16.f16 %{{.*}} : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
281+
%r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
282+
(vector<4xf16>, vector<8xf16>, vector<16xf32>,
283+
i32, i32, i32) -> vector<16xf32>
284+
285+
// CHECK: rocdl.smfmac.f32.16x16x32.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
286+
%r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
287+
(vector<4xi16>, vector<8xi16>, vector<4xf32>,
288+
i32, i32, i32) -> vector<4xf32>
289+
290+
// CHECK: rocdl.smfmac.f32.32x32x16.bf16 %{{.*}} : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
291+
%r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
292+
(vector<4xi16>, vector<8xi16>, vector<16xf32>,
293+
i32, i32, i32) -> vector<16xf32>
294+
295+
// CHECK: rocdl.smfmac.i32.16x16x64.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
296+
%r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
297+
(vector<2xi32>, vector<4xi32>, vector<4xi32>,
298+
i32, i32, i32) -> vector<4xi32>
299+
300+
// CHECK: rocdl.smfmac.i32.32x32x32.i8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
301+
%r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
302+
(vector<2xi32>, vector<4xi32>, vector<16xi32>,
303+
i32, i32, i32) -> vector<16xi32>
304+
305+
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
306+
%r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
307+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
308+
i32, i32, i32) -> vector<4xf32>
309+
310+
// CHECK: rocdl.smfmac.f32.16x16x64.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
311+
%r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
312+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
313+
i32, i32, i32) -> vector<4xf32>
314+
315+
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
316+
%r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
317+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
318+
i32, i32, i32) -> vector<4xf32>
319+
320+
// CHECK: rocdl.smfmac.f32.16x16x64.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
321+
%r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
322+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
323+
i32, i32, i32) -> vector<4xf32>
324+
325+
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
326+
%r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
327+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
328+
i32, i32, i32) -> vector<16xf32>
329+
330+
// CHECK: rocdl.smfmac.f32.32x32x32.bf8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
331+
%r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
332+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
333+
i32, i32, i32) -> vector<16xf32>
334+
335+
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.bf8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
336+
%r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
337+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
338+
i32, i32, i32) -> vector<16xf32>
339+
340+
// CHECK: rocdl.smfmac.f32.32x32x32.fp8.fp8 %{{.*}} : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
341+
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
342+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
343+
i32, i32, i32) -> vector<16xf32>
344+
345+
llvm.return %r0 : vector<4 x f32>
346+
}
347+
261348
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
262349
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
263350
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) {

mlir/test/Target/LLVMIR/rocdl.mlir

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,95 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
398398
llvm.return %r0 : vector<32 x f32>
399399
}
400400

401+
// Translation coverage for the rocdl.smfmac.* ops: each of the 14 variants is
// expected to become a call to the same-named @llvm.amdgcn.smfmac.* intrinsic,
// with the shared constant %csti32 showing up as the literal `i32 42` in the
// last three call arguments. Only %r0 is returned.
// NOTE(review): %arg0 is declared but never referenced in the body.
llvm.func @rocdl.smfmac(%arg0 : i32,
402+
%arg1 : vector<4 x f16>,
403+
%arg2 : vector<8 x f16>,
404+
%arg3 : vector<4 x f32>,
405+
%arg4 : vector<16 x f32>,
406+
%arg5 : vector<4 x i16>,
407+
%arg6 : vector<8 x i16>,
408+
%arg7 : vector<2xi32>,
409+
%arg8 : vector<4xi32>,
410+
%arg9 : vector<16xi32>) -> vector<4 x f32> {
411+
%csti32 = llvm.mlir.constant(42 : i32) : i32
412+
413+
// CHECK-LABEL: rocdl.smfmac
414+
415+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
416+
%r0 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %csti32, %csti32, %csti32 :
417+
(vector<4xf16>, vector<8xf16>, vector<4xf32>,
418+
i32, i32, i32) -> vector<4xf32>
419+
420+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
421+
%r1 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %csti32, %csti32, %csti32 :
422+
(vector<4xf16>, vector<8xf16>, vector<16xf32>,
423+
i32, i32, i32) -> vector<16xf32>
424+
425+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
426+
%r2 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %csti32, %csti32, %csti32 :
427+
(vector<4xi16>, vector<8xi16>, vector<4xf32>,
428+
i32, i32, i32) -> vector<4xf32>
429+
430+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
431+
%r3 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %csti32, %csti32, %csti32 :
432+
(vector<4xi16>, vector<8xi16>, vector<16xf32>,
433+
i32, i32, i32) -> vector<16xf32>
434+
435+
// CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
436+
%r4 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %csti32, %csti32, %csti32 :
437+
(vector<2xi32>, vector<4xi32>, vector<4xi32>,
438+
i32, i32, i32) -> vector<4xi32>
439+
440+
// CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
441+
%r5 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %csti32, %csti32, %csti32 :
442+
(vector<2xi32>, vector<4xi32>, vector<16xi32>,
443+
i32, i32, i32) -> vector<16xi32>
444+
445+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
446+
%r6 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
447+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
448+
i32, i32, i32) -> vector<4xf32>
449+
450+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
451+
%r7 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
452+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
453+
i32, i32, i32) -> vector<4xf32>
454+
455+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
456+
%r8 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
457+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
458+
i32, i32, i32) -> vector<4xf32>
459+
460+
// CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
461+
%r9 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %csti32, %csti32, %csti32 :
462+
(vector<2xi32>, vector<4xi32>, vector<4xf32>,
463+
i32, i32, i32) -> vector<4xf32>
464+
465+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
466+
%r10 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
467+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
468+
i32, i32, i32) -> vector<16xf32>
469+
470+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
471+
%r11 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
472+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
473+
i32, i32, i32) -> vector<16xf32>
474+
475+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
476+
%r12 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
477+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
478+
i32, i32, i32) -> vector<16xf32>
479+
480+
481+
// CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
482+
%r13 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %csti32, %csti32, %csti32 :
483+
(vector<2xi32>, vector<4xi32>, vector<16xf32>,
484+
i32, i32, i32) -> vector<16xf32>
485+
486+
llvm.return %r0 : vector<4 x f32>
487+
}
488+
489+
401490
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4(%arg0 : i32,
402491
%arg1 : vector<16 x f32>, %arg2 : vector<8xi32>,
403492
%arg3 : vector<6xi32>, %arg4 : vector<4xi32>) -> vector<16 x f32> {

0 commit comments

Comments
 (0)