1
- // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 | FileCheck %s
1
+ // RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx940 -cse | FileCheck %s
2
2
func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
3
3
%arg2 : vector<16xf32>, %arg3 : vector<4xf32>,
4
4
%arg4 : vector<4xf16>, %arg5 : vector<4xi8>,
@@ -28,7 +28,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
28
28
amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
29
29
// CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
30
30
amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
31
- // CHECK: rocdl.mfma.i32.32x32x4i8{{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
31
+ // CHECK: %[[BITCAST_4xi8_i32:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
32
+ // CHECK: rocdl.mfma.i32.32x32x4i8 %[[BITCAST_4xi8_i32]], %[[BITCAST_4xi8_i32]], {{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
32
33
amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
33
34
// CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
34
35
amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
@@ -38,7 +39,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
38
39
amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
39
40
// CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
40
41
amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
41
- // CHECK: rocdl.mfma.f32.32x32x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
42
+ // CHECK: %[[BITCAST_2xbf16_2xi16:.+]] = llvm.bitcast {{.*}} : vector<2xbf16> to vector<2xi16>
43
+ // CHECK: rocdl.mfma.f32.32x32x2bf16 %[[BITCAST_2xbf16_2xi16]], %[[BITCAST_2xbf16_2xi16]], %{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
42
44
amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
43
45
// CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
44
46
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
@@ -48,7 +50,8 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
48
50
amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
49
51
// CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
50
52
amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
51
- // CHECK: rocdl.mfma.f32.32x32x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
53
+ // CHECK: %[[BITCAST_4xbf16_4xi16:.+]] = llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
54
+ // CHECK: rocdl.mfma.f32.32x32x4bf16.1k %[[BITCAST_4xbf16_4xi16]], %[[BITCAST_4xbf16_4xi16]], {{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
52
55
amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
53
56
// CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
54
57
amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
@@ -62,17 +65,20 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
62
65
amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64>
63
66
// CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
64
67
amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
65
- // CHECK: rocdl.mfma.i32.16x16x32.i8{{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
68
+ // CHECK: %[[BITCAST_8xi8_i64:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
69
+ // CHECK: rocdl.mfma.i32.16x16x32.i8 %[[BITCAST_8xi8_i64]], %[[BITCAST_8xi8_i64]], {{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
66
70
amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
67
71
// CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
68
72
amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
69
73
// CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
70
74
amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
71
75
// CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
72
76
amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
73
- // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
77
+ // CHECK: %[[BITCAST_8xi8_i64_1:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
78
+ // CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_1]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
74
79
amdgpu.mfma %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
75
- // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
80
+ // CHECK: %[[BITCAST_8xi8_i64_2:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
81
+ // CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_2]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
76
82
amdgpu.mfma %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
77
83
// CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
78
84
amdgpu.mfma %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
0 commit comments