@@ -51,3 +51,54 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
51
51
52
52
func.return
53
53
}
54
+
// Verifies the lowering of amdgpu.scaled_mfma to the ROCDL scaled MFMA
// intrinsics. Each of the five source element types in the signature
// (f8E4M3FN, f8E5M2, f6E2M3FN, f6E3M2FN, f4E2M1FN) is exercised with both tile
// shapes used here: m = n = 32, k = 64 accumulating into vector<16xf32>, and
// m = n = 16, k = 128 accumulating into vector<4xf32>.
//
// Every op takes its first scale from %arg7 (a vector<4xf8E8M0FNU>, written
// with index [0]) and its second scale from %arg8 (a scalar f8E8M0FNU, written
// with index [1]).
// CHECK-LABEL: func @scaled_mfma_to_rocdl(
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
                                %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
                                %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
                                %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
                                %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {

  // Values captured once here and then expected, in this order, as the last
  // four operands of every intrinsic matched below.
  // NOTE(review): b0 and z0 look like the i32 forms of the vector scale
  // (%arg7) and the scalar scale (%arg8) respectively -- confirm against the
  // lowering pattern.
  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
  // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32

  // f8E4M3FN sources: the intrinsic is expected to take them as vector<8xi32>.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>

  // f8E5M2 sources: also expected as vector<8xi32>.
  // CHECK: llvm.bitcast

  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>

  // f6E2M3FN sources: expected as vector<6xi32>.
  // CHECK: llvm.bitcast

  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>

  // f6E3M2FN sources: expected as vector<6xi32>.
  // NOTE(review): the i32 constant 3 matched here is presumably the source
  // format selector for this element type -- confirm against the lowering.
  // CHECK: llvm.bitcast
  // CHECK: llvm.mlir.constant(3 : i32) : i32

  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>

  // f4E2M1FN sources: expected as vector<4xi32>.
  // NOTE(review): the i32 constant 4 matched here is presumably the source
  // format selector for this element type -- confirm against the lowering.
  // CHECK: llvm.bitcast
  // CHECK: llvm.mlir.constant(4 : i32) : i32

  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>

  func.return
}
0 commit comments