1
1
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
2
2
3
- // CHECK-LABEL: func @hoist_vector_transfer_pairs(
4
- // CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
5
- // CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
6
- // CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
7
- // CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
8
- // CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9
- // CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
10
- // CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
11
- // CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
12
- // CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
13
- // CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
14
- // CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
15
- func.func @hoist_vector_transfer_pairs (
16
- %memref0: memref <?x?xf32 >, %memref1: memref <?x?xf32 >, %memref2: memref <?x?xf32 >,
17
- %memref3: memref <?x?xf32 >, %memref4: memref <?x?xf32 >, %memref5: memref <?x?xf32 >,
18
- %val: index , %lb : index , %ub : index , %step: index , %cmp: i1 ) {
3
+ ///----------------------------------------------------------------------------------------
4
+ /// Tests for vector.transfer_read + vector.transfer_write pairs
5
+ ///
6
+ /// * Nested in double loops
7
+ /// * Indices depend on induction variables
8
+ ///----------------------------------------------------------------------------------------
9
+
10
+ // CHECK-LABEL: func @mem_use_outside
11
+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
12
+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
13
+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
14
+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
15
+ func.func @mem_use_outside (%mem: memref <?x?xf32 >, %lb : index , %ub : index , %step: index ) {
16
+ %pad = arith.constant 0.0 : f32
17
+
18
+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
19
+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
20
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
21
+ // CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
22
+ // CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
23
+ // CHECK: scf.yield %[[USE]] : vector<1xf32>
24
+ // CHECK: }
25
+ // CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
26
+ // CHECK: }
27
+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
28
+ scf.for %i = %lb to %ub step %step {
29
+ scf.for %j = %lb to %ub step %step {
30
+ %read = vector.transfer_read %mem [%i , %i ], %pad: memref <?x?xf32 >, vector <1 xf32 >
31
+ %use = " val_use" (%read ) : (vector <1 xf32 >) -> vector <1 xf32 >
32
+ vector.transfer_write %use , %mem [%i , %i ] : vector <1 xf32 >, memref <?x?xf32 >
33
+ }
34
+ }
35
+ " mem_use" (%mem ) : (memref <?x?xf32 >) -> ()
36
+ return
37
+ }
38
+
39
+ module attributes {transform.with_named_sequence } {
40
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
41
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
42
+ : (!transform.any_op ) -> !transform.any_op
43
+ transform.structured.hoist_redundant_vector_transfers %0
44
+ : (!transform.any_op ) -> !transform.any_op
45
+ transform.yield
46
+ }
47
+ }
48
+
49
+ // -----
50
+
51
+ // CHECK-LABEL: func @mem_use_inside_outer_loop
52
+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
53
+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
54
+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
55
+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
56
+ func.func @mem_use_inside_outer_loop (%mem: memref <?x?xf32 >, %lb : index , %ub : index , %step: index ) {
57
+ %pad = arith.constant 0.0 : f32
58
+
59
+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
60
+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
61
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
62
+ // CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
63
+ // CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
64
+ // CHECK: scf.yield %[[USE]] : vector<1xf32>
65
+ // CHECK: }
66
+ // CHECK: vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
67
+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
68
+ // CHECK: }
69
+ scf.for %i = %lb to %ub step %step {
70
+ scf.for %j = %lb to %ub step %step {
71
+ %read = vector.transfer_read %mem [%i , %i ], %pad: memref <?x?xf32 >, vector <1 xf32 >
72
+ %use = " val_use" (%read ) : (vector <1 xf32 >) -> vector <1 xf32 >
73
+ vector.transfer_write %use , %mem [%i , %i ] : vector <1 xf32 >, memref <?x?xf32 >
74
+ }
75
+ " mem_use" (%mem ) : (memref <?x?xf32 >) -> ()
76
+ }
77
+ return
78
+ }
79
+
80
+ module attributes {transform.with_named_sequence } {
81
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
82
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
83
+ : (!transform.any_op ) -> !transform.any_op
84
+ transform.structured.hoist_redundant_vector_transfers %0
85
+ : (!transform.any_op ) -> !transform.any_op
86
+ transform.yield
87
+ }
88
+ }
89
+
90
+ // -----
91
+
92
+ ///----------------------------------------------------------------------------------------
93
+ /// Tests for vector.transfer_read + vector.transfer_write pairs
94
+ ///
95
+ /// * Nested in double loops
96
+ /// * Indices are constant
97
+ ///----------------------------------------------------------------------------------------
98
+
99
+ // CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write
100
+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
101
+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
102
+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
103
+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
104
+ func.func @negative_mem_use_inside_inner_loop_before_write (%mem: memref <?x?xf32 >, %lb : index , %ub : index , %step: index ) {
19
105
%c0 = arith.constant 0 : index
20
- %cst = arith.constant 0.0 : f32
106
+ %pad = arith.constant 0.0 : f32
107
+
108
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
109
+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
110
+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
111
+ // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
112
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
113
+ // CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
114
+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
115
+ // CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
116
+ // CHECK: }
117
+ // CHECK: }
118
+ scf.for %i = %lb to %ub step %step {
119
+ scf.for %j = %lb to %ub step %step {
120
+ %read = vector.transfer_read %mem [%c0 , %c0 ], %pad: memref <?x?xf32 >, vector <1 xf32 >
121
+ %use = " val_use" (%read ) : (vector <1 xf32 >) -> vector <1 xf32 >
122
+ " mem_use" (%mem ) : (memref <?x?xf32 >) -> ()
123
+ vector.transfer_write %use , %mem [%c0 , %c0 ] : vector <1 xf32 >, memref <?x?xf32 >
124
+ }
125
+ }
126
+ return
127
+ }
21
128
22
- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
23
- // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
24
- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
25
- // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
26
- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
27
- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
28
- // CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
29
- // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
30
- // CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
31
- // CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
32
- // CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
33
- // CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
34
- // CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
35
- // CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
36
- // CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
37
- // CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
38
- // CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
39
- // CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
129
+ module attributes {transform.with_named_sequence } {
130
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
131
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
132
+ : (!transform.any_op ) -> !transform.any_op
133
+ transform.structured.hoist_redundant_vector_transfers %0
134
+ : (!transform.any_op ) -> !transform.any_op
135
+ transform.yield
136
+ }
137
+ }
138
+
139
+ // -----
140
+
141
+ // CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write
142
+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
143
+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
144
+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
145
+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
146
+ func.func @negative_mem_use_inside_inner_loop_after_write (%mem: memref <?x?xf32 >, %lb : index , %ub : index , %step: index ) {
147
+ %c0 = arith.constant 0 : index
148
+ %pad = arith.constant 0.0 : f32
149
+
150
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
151
+ // CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
152
+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
153
+ // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
154
+ // CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
155
+ // CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
156
+ // CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
157
+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
158
+ // CHECK: }
159
+ // CHECK: }
160
+ scf.for %i = %lb to %ub step %step {
161
+ scf.for %j = %lb to %ub step %step {
162
+ %read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
163
+ %use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
164
+ vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
165
+ "mem_use"(%mem) : (memref<?x?xf32>) -> () // opaque use of %mem after the write; the CHECK lines expect the read/write pair to stay in the inner loop
166
+ }
167
+ }
168
+ return
169
+ }
170
+
171
+ module attributes {transform.with_named_sequence } {
172
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
173
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
174
+ : (!transform.any_op ) -> !transform.any_op
175
+ transform.structured.hoist_redundant_vector_transfers %0
176
+ : (!transform.any_op ) -> !transform.any_op
177
+ transform.yield
178
+ }
179
+ }
180
+
181
+ // -----
182
+
183
+ // CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read
184
+ // CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
185
+ // CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
186
+ // CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
187
+ // CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
188
+ func.func @negative_mem_use_inside_inner_loop_before_read (%mem: memref <?x?xf32 >, %lb : index , %ub : index , %step: index ) {
189
+ %c0 = arith.constant 0 : index
190
+ %pad = arith.constant 0.0 : f32
191
+
192
+ // CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
193
+ // CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
194
+ // CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
195
+ // CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
196
+ // CHECK: "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
197
+ // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
40
198
// CHECK: }
41
- // CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
42
- // CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
43
- // CHECK: scf.yield {{.*}} : vector<1xf32>
44
199
// CHECK: }
45
- // CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
46
- // CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
47
200
scf.for %i = %lb to %ub step %step {
48
201
scf.for %j = %lb to %ub step %step {
49
- %r0 = vector.transfer_read %memref1 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <1 xf32 >
50
- %r1 = vector.transfer_read %memref0 [%i , %i ], %cst: memref <?x?xf32 >, vector <2 xf32 >
51
- %r2 = vector.transfer_read %memref2 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <3 xf32 >
52
- %r3 = vector.transfer_read %memref3 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <4 xf32 >
53
- " some_crippling_use" (%memref4 ) : (memref <?x?xf32 >) -> ()
54
- %r4 = vector.transfer_read %memref4 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <5 xf32 >
55
- %r5 = vector.transfer_read %memref5 [%c0 , %c0 ], %cst: memref <?x?xf32 >, vector <6 xf32 >
56
- " some_crippling_use" (%memref5 ) : (memref <?x?xf32 >) -> ()
57
- %u0 = " some_use" (%r0 ) : (vector <1 xf32 >) -> vector <1 xf32 >
58
- %u1 = " some_use" (%r1 ) : (vector <2 xf32 >) -> vector <2 xf32 >
59
- %u2 = " some_use" (%memref2 , %r2 ) : (memref <?x?xf32 >, vector <3 xf32 >) -> vector <3 xf32 >
60
- %u3 = " some_use" (%r3 ) : (vector <4 xf32 >) -> vector <4 xf32 >
61
- %u4 = " some_use" (%r4 ) : (vector <5 xf32 >) -> vector <5 xf32 >
62
- %u5 = " some_use" (%r5 ) : (vector <6 xf32 >) -> vector <6 xf32 >
63
- vector.transfer_write %u0 , %memref1 [%c0 , %c0 ] : vector <1 xf32 >, memref <?x?xf32 >
64
- vector.transfer_write %u1 , %memref0 [%i , %i ] : vector <2 xf32 >, memref <?x?xf32 >
65
- vector.transfer_write %u2 , %memref2 [%c0 , %c0 ] : vector <3 xf32 >, memref <?x?xf32 >
66
- vector.transfer_write %u3 , %memref3 [%c0 , %c0 ] : vector <4 xf32 >, memref <?x?xf32 >
67
- vector.transfer_write %u4 , %memref4 [%c0 , %c0 ] : vector <5 xf32 >, memref <?x?xf32 >
68
- vector.transfer_write %u5 , %memref5 [%c0 , %c0 ] : vector <6 xf32 >, memref <?x?xf32 >
69
- " some_crippling_use" (%memref3 ) : (memref <?x?xf32 >) -> ()
202
+ " mem_use" (%mem ) : (memref <?x?xf32 >) -> ()
203
+ %read = vector.transfer_read %mem [%c0 , %c0 ], %pad: memref <?x?xf32 >, vector <1 xf32 >
204
+ %use = " val_use" (%read ) : (vector <1 xf32 >) -> vector <1 xf32 >
205
+ vector.transfer_write %use , %mem [%c0 , %c0 ] : vector <1 xf32 >, memref <?x?xf32 >
70
206
}
71
- " unrelated_use" (%memref0 ) : (memref <?x?xf32 >) -> ()
72
207
}
73
- " unrelated_use" (%memref1 ) : (memref <?x?xf32 >) -> ()
74
208
return
75
209
}
76
210
@@ -86,6 +220,12 @@ module attributes {transform.with_named_sequence} {
86
220
87
221
// -----
88
222
223
+ ///----------------------------------------------------------------------------------------
224
+ /// Other tests
225
+ ///
226
+ /// TODO: Document
227
+ ///----------------------------------------------------------------------------------------
228
+
89
229
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
90
230
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
91
231
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
0 commit comments