10
10
// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
11
11
// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
12
12
13
- // RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
13
+ // RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
14
14
15
15
#packed_maps = [
16
16
affine_map <(d0 , d1 , d2 ) -> (d0 , d2 )>,
20
20
21
21
func.func private @setArmVLBits (%bits : i32 )
22
22
23
+ func.func private @prepareAccTestData (%in: vector <4 x4 xi32 >) -> vector <4 x[4 ]xi32 > {
24
+ %c0 = arith.constant 0 : index
25
+ %c0_i32 = arith.constant 0 : i32
26
+
27
+ %mem = memref.alloca () : memref <4 x4 xi32 >
28
+ vector.transfer_write %in , %mem [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
29
+
30
+ %flat_mem = memref.collapse_shape %mem [[0 , 1 ]] : memref <4 x4 xi32 > into memref <16 xi32 >
31
+ %flat_vec = vector.transfer_read %flat_mem [%c0 ], %c0_i32 {in_bounds = [true ]} : memref <16 xi32 >, vector <[16 ]xi32 >
32
+ %out = vector.shape_cast %flat_vec : vector <[16 ]xi32 > to vector <4 x[4 ]xi32 >
33
+
34
+ return %out : vector <4 x[4 ]xi32 >
35
+ }
36
+
37
+ func.func private @prepareLHSTestData (%in: vector <4 x8 xi8 >) -> vector <4 x8 xi8 > {
38
+ %c0 = arith.constant 0 : index
39
+ %c0_i8 = arith.constant 0 : i8
40
+
41
+ %mem = memref.alloca () : memref <4 x8 xi8 >
42
+ vector.transfer_write %in , %mem [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
43
+
44
+ %out = vector.transfer_read %mem [%c0 , %c0 ], %c0_i8 : memref <4 x8 xi8 >, vector <4 x8 xi8 >
45
+
46
+ return %out : vector <4 x8 xi8 >
47
+ }
48
+
49
+ func.func private @prepareRHSTestData (%in: vector <4 x8 xi8 >) -> vector <[32 ]xi8 > {
50
+ %c0 = arith.constant 0 : index
51
+ %c0_i8 = arith.constant 0 : i8
52
+
53
+ %mem = memref.alloca () : memref <4 x8 xi8 >
54
+ vector.transfer_write %in , %mem [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
55
+
56
+ %flat_mem = memref.collapse_shape %mem [[0 , 1 ]] : memref <4 x8 xi8 > into memref <32 xi8 >
57
+ %flat_vec = vector.transfer_read %flat_mem [%c0 ], %c0_i8 {in_bounds = [true ]} : memref <32 xi8 >, vector <[32 ]xi8 >
58
+
59
+ return %flat_vec : vector <[32 ]xi8 >
60
+ }
61
+
23
62
func.func @main () {
24
63
%c128 = arith.constant 128 : i32
25
64
func.call @setArmVLBits (%c128 ) : (i32 ) -> ()
@@ -28,68 +67,32 @@ func.func @main() {
28
67
%c0_i32 = arith.constant 0 : i32
29
68
%c0_i8 = arith.constant 0 : i8
30
69
31
- // Accumulator test data
70
+ // Accumulator test data
32
71
%acc_cst = arith.constant dense <[[-44 , 20 , 44 , -46 ],
33
72
[ -8 , 25 , -34 , 26 ],
34
73
[-20 , -36 , -3 , 39 ],
35
74
[-48 , -31 , -25 , -21 ]]> : vector <4 x4 xi32 >
36
- %acc_m = memref.alloca () : memref <4 x4 xi32 >
37
- vector.transfer_write %acc_cst , %acc_m [%c0 , %c0 ] : vector <4 x4 xi32 >, memref <4 x4 xi32 >
38
-
39
- %acc_m1 = memref.collapse_shape %acc_m [[0 , 1 ]] : memref <4 x4 xi32 > into memref <16 xi32 >
40
- %acc_flat = vector.transfer_read %acc_m1 [%c0 ], %c0_i32 {in_bounds = [true ]} : memref <16 xi32 >, vector <[16 ]xi32 >
41
- %acc = vector.shape_cast %acc_flat : vector <[16 ]xi32 > to vector <4 x[4 ]xi32 >
42
-
43
- vector.print str " ACC:\n "
44
- %acc0 = vector.extract %acc [0 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
45
- %acc1 = vector.extract %acc [1 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
46
- %acc2 = vector.extract %acc [2 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
47
- %acc3 = vector.extract %acc [3 ] : vector <[4 ]xi32 > from vector <4 x[4 ]xi32 >
48
- vector.print %acc0 : vector <[4 ]xi32 >
49
- vector.print %acc1 : vector <[4 ]xi32 >
50
- vector.print %acc2 : vector <[4 ]xi32 >
51
- vector.print %acc3 : vector <[4 ]xi32 >
75
+
76
+ %acc = func.call @prepareAccTestData (%acc_cst ) : (vector <4 x4 xi32 >) -> vector <4 x[4 ]xi32 >
52
77
53
78
// LHS test data
54
79
%lhs_cst = arith.constant dense <[[-35 , -27 , -36 , -31 , 23 , -34 , -8 , -33 ],
55
- [-20 , 17 , -32 , -47 , 37 , 22 , -7 , -21 ],
56
- [ -7 , -35 , 20 , -4 , 39 , 46 , -23 , 40 ],
57
- [ 40 , 27 , 37 , 43 , 38 , -6 , 37 , 49 ]]> : vector <4 x8 xi8 >
58
-
59
- %lhs_m = memref.alloca () : memref <4 x8 xi8 >
60
- vector.transfer_write %lhs_cst , %lhs_m [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
61
- %lhs = vector.transfer_read %lhs_m [%c0 , %c0 ], %c0_i8 : memref <4 x8 xi8 >, vector <4 x8 xi8 >
62
-
63
- vector.print str " LHS:\n "
64
- %lhs0 = vector.extract %lhs [0 ] : vector <8 xi8 > from vector <4 x8 xi8 >
65
- %lhs1 = vector.extract %lhs [1 ] : vector <8 xi8 > from vector <4 x8 xi8 >
66
- %lhs2 = vector.extract %lhs [2 ] : vector <8 xi8 > from vector <4 x8 xi8 >
67
- %lhs3 = vector.extract %lhs [3 ] : vector <8 xi8 > from vector <4 x8 xi8 >
68
- vector.print %lhs0 : vector <8 xi8 >
69
- vector.print %lhs1 : vector <8 xi8 >
70
- vector.print %lhs2 : vector <8 xi8 >
71
- vector.print %lhs3 : vector <8 xi8 >
80
+ [-20 , 17 , -32 , -47 , 37 , 22 , -7 , -21 ],
81
+ [ -7 , -35 , 20 , -4 , 39 , 46 , -23 , 40 ],
82
+ [ 40 , 27 , 37 , 43 , 38 , -6 , 37 , 49 ]]> : vector <4 x8 xi8 >
83
+
84
+ %lhs = func.call @prepareLHSTestData (%lhs_cst ) : (vector <4 x8 xi8 >) -> vector <4 x8 xi8 >
72
85
73
86
// RHS test data
74
87
%rhs_cst = arith.constant dense <[[-17 , -50 , -1 , 48 , -13 , 22 , 39 , 33 ],
75
88
[-35 , -24 , 37 , -32 , 33 , 30 , -11 , -17 ],
76
89
[-28 , 31 , 3 , -44 , -15 , -27 , 22 , 35 ],
77
90
[-23 , 39 , 48 , 26 , -23 , 32 , -39 , -38 ]]> : vector <4 x8 xi8 >
78
-
79
- %rhs_m = memref.alloca () : memref <4 x8 xi8 >
80
- vector.transfer_write %rhs_cst , %rhs_m [%c0 , %c0 ] : vector <4 x8 xi8 >, memref <4 x8 xi8 >
81
-
82
- %rhs_m1 = memref.collapse_shape %rhs_m [[0 , 1 ]] : memref <4 x8 xi8 > into memref <32 xi8 >
83
- %rhs_flat = vector.transfer_read %rhs_m1 [%c0 ], %c0_i8 {in_bounds = [true ]} : memref <32 xi8 >, vector <[32 ]xi8 >
84
-
85
- vector.print str " RHS:\n "
86
- %rhs0 = vector.scalable.extract %rhs_flat [0 ] : vector <[16 ]xi8 > from vector <[32 ]xi8 >
87
- %rhs1 = vector.scalable.extract %rhs_flat [16 ] : vector <[16 ]xi8 > from vector <[32 ]xi8 >
88
- vector.print %rhs0 : vector <[16 ]xi8 >
89
- vector.print %rhs1 : vector <[16 ]xi8 >
90
-
91
+ %rhs_flat = func.call @prepareRHSTestData (%rhs_cst ) : (vector <4 x8 xi8 >) -> vector <[32 ]xi8 >
91
92
%rhs = vector.shape_cast %rhs_flat : vector <[32 ]xi8 > to vector <[4 ]x8 xi8 >
92
93
94
+ // CHECK-IR-COUNT-4: arm_sve.intr.smmla
95
+
93
96
// Matrix multiplication
94
97
%0 = arith.extsi %lhs : vector <4 x8 xi8 > to vector <4 x8 xi32 >
95
98
%1 = arith.extsi %rhs : vector <[4 ]x8 xi8 > to vector <[4 ]x8 xi32 >
0 commit comments