// RUN: %{compile}

- // RUN: %{run} | FileCheck %s
+ // RUN: %{run} | FileCheck %s --check-prefix=F32
+
+ // REDEFINE: %{entry_point} = matmul_mixed_ty
+ // RUN: %{run} | FileCheck %s --check-prefix=MIXED
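+ // The REDEFINE above switches the entry point to the mixed-type variant, whose
+ // output is checked against the MIXED prefix.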

func.func @matmul_f32() {
  // Matrix dimensions
@@ -32,37 +35,75 @@ func.func @matmul_f32() {
  %C_out = linalg.matmul ins(%A, %B: tensor<?x?xf32>, tensor<?x?xf32>) outs(%C_in: tensor<?x?xf32>) -> tensor<?x?xf32>

  // Print and verify the output
-   // CHECK-LABEL: SVE: START OF TEST OUTPUT
+   // F32-LABEL: SVE: START OF TEST OUTPUT
  vector.print str "SVE: START OF TEST OUTPUT"

-   // CHECK-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
-   // CHECK-COUNT-5: [29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788]
+   // F32-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
+   // F32-COUNT-5: [29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788, 29.5788]
  %xf = tensor.cast %C_out : tensor<?x?xf32> to tensor<*xf32>
  call @printMemrefF32(%xf) : (tensor<*xf32>) -> ()

-   // CHECK-NEXT: SVE: END OF TEST OUTPUT
+   // F32-NEXT: SVE: END OF TEST OUTPUT
+   vector.print str "SVE: END OF TEST OUTPUT"
+
+   return
+ }
+
+ func.func @matmul_mixed_ty() {
+   // Matrix dimensions
+   %K = arith.constant 3 : index
+   %M = arith.constant 5 : index
+   %N = arith.constant 15 : index
+   %c0_i8 = arith.constant 0 : i8
+   %c0_i32 = arith.constant 0 : i32
+
+   // Allocate the matrices
+   %A_alloc = bufferization.alloc_tensor(%M, %K) : tensor<?x?xi8>
+   %B_alloc = bufferization.alloc_tensor(%K, %N) : tensor<?x?xi8>
+   %C_alloc = bufferization.alloc_tensor(%M, %N) : tensor<?x?xi32>
+
+   // Initialise the matrices
+   %pi = arith.constant 123 : i8
+   %A = linalg.fill ins(%pi : i8) outs(%A_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+   %B = linalg.fill ins(%pi : i8) outs(%B_alloc : tensor<?x?xi8>) -> tensor<?x?xi8>
+   %C_in = linalg.fill ins(%c0_i32 : i32) outs(%C_alloc : tensor<?x?xi32>) -> tensor<?x?xi32>
+
+   // Matmul
+   %C_out = linalg.matmul ins(%A, %B: tensor<?x?xi8>, tensor<?x?xi8>) outs(%C_in: tensor<?x?xi32>) -> tensor<?x?xi32>
+
+   // Print and verify the output
+   // MIXED-LABEL: SVE: START OF TEST OUTPUT
+   vector.print str "SVE: START OF TEST OUTPUT"
+
+   // MIXED-NEXT: Unranked Memref {{.*}} rank = 2 offset = 0 sizes = [5, 15] strides = [15, 1] data =
+   // MIXED-COUNT-5: [45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387, 45387]
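+   // A and B are filled with 123 and the reduction dimension K is 3, so every
+   // element of C is 3 * 123 * 123 = 45387, the value checked above.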
+   %xf = tensor.cast %C_out : tensor<?x?xi32> to tensor<*xi32>
+   call @printMemrefI32(%xf) : (tensor<*xi32>) -> ()
+
+   // MIXED-NEXT: SVE: END OF TEST OUTPUT
  vector.print str "SVE: END OF TEST OUTPUT"

  return
}

module attributes {transform.with_named_sequence} {
-   transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
-     %matmul = transform.structured.match ops {["linalg.matmul"]} in %module
-       : (!transform.any_op) -> !transform.any_op
+   // A sequence that will tile and vectorise a Matmul Op
+   transform.named_sequence @tile_and_vectorize_matmul(%func
+     : !transform.op<"func.func"> {transform.readonly}) {
+
+     // Step 0: Get a handle to the matmul Op
+     %matmul = transform.structured.match ops {["linalg.matmul"]} in %func
+       : (!transform.op<"func.func">) -> !transform.any_op

    // Step 1: Tile
-     %module_with_tiled_loops, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
+     %tiled_matmul, %loops:3 = transform.structured.tile_using_for %matmul [2, [4], 1]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
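+     // Note: the bracketed size [4] is scalable, i.e. 4 x vscale, so the tiling
+     // and the vectorization below follow the runtime SVE vector length.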
+     transform.print %tiled_matmul {name = "matmul lal"} : !transform.any_op

    // Step 2: Vectorize
-     %tiled_matmul = transform.structured.match ops {["linalg.matmul"]} in %module_with_tiled_loops
-       : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %tiled_matmul vector_sizes [2, [4], 1] : !transform.any_op

    // Step 3: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
-     %func = transform.structured.match ops {["func.func"]} in %module
-       : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func {
      transform.apply_patterns.vector.reduction_to_contract
      transform.apply_patterns.vector.transfer_permutation_patterns
@@ -77,6 +118,21 @@ transform.named_sequence @__transform_main(%module: !transform.any_op {transform

    transform.yield
  }
+
+   // A sequence that goes over all functions in this module and applies
+   // "tile_and_vectorize_matmul"
+   transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+     %funcs = transform.structured.match ops {["func.func"]} in %module
+       : (!transform.any_op) -> !transform.op<"func.func">
+
+     transform.foreach %funcs : !transform.op<"func.func"> {
+     ^bb2(%func : !transform.op<"func.func">):
+       transform.include @tile_and_vectorize_matmul failures(propagate)
+         (%func) : (!transform.op<"func.func">) -> ()
+     }
+     transform.yield
+   }
}

func.func private @printMemrefF32(%ptr : tensor<*xf32>)
+ func.func private @printMemrefI32(%ptr : tensor<*xi32>)