Skip to content

Commit a34bec5

Browse files
[fixup] Check we are generating the expected number and kind of LLVM intrinsics
1 parent b5865b5 commit a34bec5

File tree

5 files changed

+70
-62
lines changed

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
1111
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
1212

13-
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13+
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1414

1515
#packed_maps = [
1616
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -20,6 +20,45 @@
2020

2121
func.func private @setArmVLBits(%bits : i32)
2222

23+
// Test helper: converts a fixed-size 4x4 i32 accumulator into the
// vector<4x[4]xi32> form (scalable trailing dimension) by writing it to a
// buffer and reading it back as a flat scalable vector.
func.func private @prepareAccTestData(%in: vector<4x4xi32>) -> vector<4x[4]xi32> {
24+
%c0 = arith.constant 0 : index
25+
// Padding value for the transfer_read below.
%c0_i32 = arith.constant 0 : i32
26+
27+
// Store the input into a 4x4 buffer.
%mem = memref.alloca() : memref<4x4xi32>
28+
vector.transfer_write %in, %mem[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
29+
30+
// View the buffer as 16 contiguous elements and load it as one scalable vector.
// NOTE(review): an in-bounds vector<[16]xi32> read from a 16-element memref
// presumably relies on vscale == 1 (@main calls @setArmVLBits with 128) - confirm.
%flat_mem = memref.collapse_shape %mem [[0, 1]] : memref<4x4xi32> into memref<16xi32>
31+
%flat_vec = vector.transfer_read %flat_mem[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
32+
// Reinterpret the flat scalable vector as 4 rows of [4] elements.
%out = vector.shape_cast %flat_vec : vector<[16]xi32> to vector<4x[4]xi32>
33+
34+
return %out : vector<4x[4]xi32>
35+
}
36+
37+
// Test helper: round-trips the 4x8 i8 LHS through a buffer and returns it with
// the same fixed-size vector type.
// NOTE(review): presumably the memory round-trip keeps the constant input from
// being folded away before the contraction is lowered - confirm.
func.func private @prepareLHSTestData(%in: vector<4x8xi8>) -> vector<4x8xi8> {
38+
%c0 = arith.constant 0 : index
39+
// Padding value for the transfer_read below.
%c0_i8 = arith.constant 0 : i8
40+
41+
// Store the input into a 4x8 buffer.
%mem = memref.alloca() : memref<4x8xi8>
42+
vector.transfer_write %in, %mem[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
43+
44+
// Load it back unchanged in shape and element type.
%out = vector.transfer_read %mem[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
45+
46+
return %out : vector<4x8xi8>
47+
}
48+
49+
// Test helper: converts the fixed-size 4x8 i8 RHS into a flat scalable
// vector<[32]xi8> by writing it to a buffer and reading it back through a
// collapsed 1-D view. The caller in @main shape_casts the result to
// vector<[4]x8xi8>.
func.func private @prepareRHSTestData(%in: vector<4x8xi8>) -> vector<[32]xi8> {
50+
%c0 = arith.constant 0 : index
51+
// Padding value for the transfer_read below.
%c0_i8 = arith.constant 0 : i8
52+
53+
// Store the input into a 4x8 buffer.
%mem = memref.alloca() : memref<4x8xi8>
54+
vector.transfer_write %in, %mem[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
55+
56+
// View the buffer as 32 contiguous elements and load it as one scalable vector.
// NOTE(review): an in-bounds vector<[32]xi8> read from a 32-element memref
// presumably relies on vscale == 1 (@main calls @setArmVLBits with 128) - confirm.
%flat_mem = memref.collapse_shape %mem [[0, 1]] : memref<4x8xi8> into memref<32xi8>
57+
%flat_vec = vector.transfer_read %flat_mem[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
58+
59+
return %flat_vec : vector<[32]xi8>
60+
}
61+
2362
func.func @main() {
2463
%c128 = arith.constant 128 : i32
2564
func.call @setArmVLBits(%c128) : (i32) -> ()
@@ -28,68 +67,32 @@ func.func @main() {
2867
%c0_i32 = arith.constant 0 : i32
2968
%c0_i8 = arith.constant 0 : i8
3069

31-
// Accumulator test data
70+
// Accumulator test data
3271
%acc_cst = arith.constant dense<[[-44, 20, 44, -46],
3372
[ -8, 25, -34, 26],
3473
[-20, -36, -3, 39],
3574
[-48, -31, -25, -21]]> : vector<4x4xi32>
36-
%acc_m = memref.alloca() : memref<4x4xi32>
37-
vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
38-
39-
%acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
40-
%acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
41-
%acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
42-
43-
vector.print str "ACC:\n"
44-
%acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
45-
%acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
46-
%acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
47-
%acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
48-
vector.print %acc0 : vector<[4]xi32>
49-
vector.print %acc1 : vector<[4]xi32>
50-
vector.print %acc2 : vector<[4]xi32>
51-
vector.print %acc3 : vector<[4]xi32>
75+
76+
%acc = func.call @prepareAccTestData(%acc_cst) : (vector<4x4xi32>) -> vector<4x[4]xi32>
5277

5378
// LHS test data
5479
%lhs_cst = arith.constant dense<[[-35, -27, -36, -31, 23, -34, -8, -33],
55-
[-20, 17, -32, -47, 37, 22, -7, -21],
56-
[ -7, -35, 20, -4, 39, 46, -23, 40],
57-
[ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
58-
59-
%lhs_m = memref.alloca() : memref<4x8xi8>
60-
vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
61-
%lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
62-
63-
vector.print str "LHS:\n"
64-
%lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
65-
%lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
66-
%lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
67-
%lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
68-
vector.print %lhs0 : vector<8xi8>
69-
vector.print %lhs1 : vector<8xi8>
70-
vector.print %lhs2 : vector<8xi8>
71-
vector.print %lhs3 : vector<8xi8>
80+
[-20, 17, -32, -47, 37, 22, -7, -21],
81+
[ -7, -35, 20, -4, 39, 46, -23, 40],
82+
[ 40, 27, 37, 43, 38, -6, 37, 49]]> : vector<4x8xi8>
83+
84+
%lhs = func.call @prepareLHSTestData(%lhs_cst) : (vector<4x8xi8>) -> vector<4x8xi8>
7285

7386
// RHS test data
7487
%rhs_cst = arith.constant dense<[[-17, -50, -1, 48, -13, 22, 39, 33],
7588
[-35, -24, 37, -32, 33, 30, -11, -17],
7689
[-28, 31, 3, -44, -15, -27, 22, 35],
7790
[-23, 39, 48, 26, -23, 32, -39, -38]]> : vector<4x8xi8>
78-
79-
%rhs_m = memref.alloca() : memref<4x8xi8>
80-
vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
81-
82-
%rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
83-
%rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
84-
85-
vector.print str "RHS:\n"
86-
%rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
87-
%rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
88-
vector.print %rhs0 : vector<[16]xi8>
89-
vector.print %rhs1 : vector<[16]xi8>
90-
91+
%rhs_flat = func.call @prepareRHSTestData(%rhs_cst) : (vector<4x8xi8>) -> vector<[32]xi8>
9192
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
9293

94+
// CHECK-IR-COUNT-4: arm_sve.intr.smmla
95+
9396
// Matrix multiplication
9497
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
9598
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
1111
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
1212

13-
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13+
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1414

1515
#packed_maps = [
1616
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -28,7 +28,6 @@ func.func @main() {
2828
%c0_i32 = arith.constant 0 : i32
2929
%c0_i8 = arith.constant 0 : i8
3030

31-
3231
// Accumulator test data
3332
%acc_cst = arith.constant dense<[[-44, 20, 44, -46, -8, 25, -34, 26],
3433
[-20, -36, -3, 39, -48, -31, -25, -21],
@@ -119,6 +118,8 @@ func.func @main() {
119118

120119
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
121120

121+
// CHECK-IR-COUNT-8: arm_sve.intr.smmla
122+
122123
// Matrix multiplication
123124
%0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
124125
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
@@ -146,7 +147,6 @@ func.func @main() {
146147
vector.print %u6 : vector<[4]xi32>
147148
vector.print %u7 : vector<[4]xi32>
148149

149-
150150
// CHECK: ( -2294, -1282, 2728, -410, -1328, 882, -5498, 732 )
151151
// CHECK: ( 1012, -4237, 4154, 2624, 5225, -2338, 2011, 1374 )
152152
// CHECK: ( -8, -1611, 2905, -1, -1068, -3155, -2428, 153 )

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
1111
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
1212

13-
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13+
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1414

1515
#packed_maps = [
1616
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -28,7 +28,7 @@ func.func @main() {
2828
%c0_i32 = arith.constant 0 : i32
2929
%c0_i8 = arith.constant 0 : i8
3030

31-
// Accumulator test data
31+
// Accumulator test data
3232
%acc_cst = arith.constant dense<[[-44, 20, 44, -46],
3333
[ -8, 25, -34, 26],
3434
[-20, -36, -3, 39],
@@ -90,6 +90,8 @@ func.func @main() {
9090

9191
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
9292

93+
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
94+
9395
// Matrix multiplication
9496
%0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
9597
%1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
1111
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
1212

13-
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13+
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1414

1515
#packed_maps = [
1616
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -29,8 +29,7 @@ func.func @main() {
2929
%c0_i32 = arith.constant 0 : i32
3030
%c0_i8 = arith.constant 0 : i8
3131

32-
33-
// Accumulator test data
32+
// Accumulator test data
3433
%acc_cst = arith.constant dense<[[16, 16, 48, 40],
3534
[40, 24, 35, 12],
3635
[33, 24, 29, 19],
@@ -92,6 +91,8 @@ func.func @main() {
9291

9392
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
9493

94+
// CHECK-IR-COUNT-4: arm_sve.intr.ummla
95+
9596
// Matrix multiplication
9697
%0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
9798
%1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
1111
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
1212

13-
// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13+
// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
1414

1515
#packed_maps = [
1616
affine_map<(d0, d1, d2) -> (d0, d2)>,
@@ -28,7 +28,7 @@ func.func @main() {
2828
%c0_i32 = arith.constant 0 : i32
2929
%c0_i8 = arith.constant 0 : i8
3030

31-
// Accumulator test data
31+
// Accumulator test data
3232
%acc_cst = arith.constant dense<[[-44, 20, 44, -46],
3333
[ -8, 25, -34, 26],
3434
[-20, -36, -3, 39],
@@ -90,6 +90,8 @@ func.func @main() {
9090

9191
%rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
9292

93+
// CHECK-IR-COUNT-4: arm_sve.intr.usmmla
94+
9395
// Matrix multiplication
9496
%0 = arith.extui %lhs : vector<4x8xi8> to vector<4x8xi32>
9597
%1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
@@ -109,9 +111,9 @@ func.func @main() {
109111
vector.print %u2 : vector<[4]xi32>
110112
vector.print %u3 : vector<[4]xi32>
111113

112-
// CHECK: ( 28403, 445, -2759, -11409 )
113-
// CHECK: ( 34908, 1047, 142, -7274 )
114-
// CHECK: ( 31032, 6807, -2378, 7382 )
115-
// CHECK: ( 44217, 6396, -10930, 623 )
114+
// CHECK: ( 28403, 445, -2759, -11409 )
115+
// CHECK: ( 34908, 1047, 142, -7274 )
116+
// CHECK: ( 31032, 6807, -2378, 7382 )
117+
// CHECK: ( 44217, 6396, -10930, 623 )
116118
return
117119
}

0 commit comments

Comments
 (0)