Skip to content

Commit 5f8e7ed

Browse files
authored
[mlir][linalg][nfc] Split hoisting tests into dedicated test functions (#145234)
Refactors the `@hoist_vector_transfer_pairs` test function in `hoisting.mlir` into smaller, focused test functions - each covering a specific `vector.transfer_read`/`vector.transfer_write` pair.

This makes it easier to identify which edge cases are tested, spot duplication, and write more targeted and readable check lines, with less surrounding noise.

This refactor also helped identify some issues with the original `@hoist_vector_transfer_pairs` test:
* Input variables `%val` and `%cmp` were unused.
* There were no check lines for reads from `memref5`.

**Note for reviewers (current and future):** This PR is split into small, incremental, and self-contained commits. It should be easier to follow the changes by reviewing those commits individually, rather than reading the full squashed diff. However, this will be merged as a single commit to avoid adding unnecessary history noise in-tree.
1 parent 369cbcc commit 5f8e7ed

File tree

1 file changed

+203
-63
lines changed

1 file changed

+203
-63
lines changed

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 203 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,210 @@
11
// RUN: mlir-opt -transform-interpreter -canonicalize --split-input-file --allow-unregistered-dialect %s | FileCheck %s
22

3-
// CHECK-LABEL: func @hoist_vector_transfer_pairs(
4-
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
5-
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,
6-
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
7-
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
8-
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9-
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
10-
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
11-
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
12-
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
13-
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
14-
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
15-
func.func @hoist_vector_transfer_pairs(
16-
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
17-
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
18-
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
3+
///----------------------------------------------------------------------------------------
4+
/// Tests for vector.transfer_read + vector.transfer_write pairs
5+
///
6+
/// * Nested in double loops
7+
/// * Indices depend on induction variables
8+
///----------------------------------------------------------------------------------------
9+
10+
// CHECK-LABEL: func @mem_use_outside
11+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
12+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
13+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
14+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
15+
func.func @mem_use_outside(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
16+
%pad = arith.constant 0.0 : f32
17+
18+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
19+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
20+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
21+
// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
22+
// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
23+
// CHECK: scf.yield %[[USE]] : vector<1xf32>
24+
// CHECK: }
25+
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]][%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
26+
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
27+
// CHECK: }
28+
scf.for %i = %lb to %ub step %step {
29+
scf.for %j = %lb to %ub step %step {
30+
%read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
31+
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
32+
vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
33+
}
34+
}
35+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
36+
return
37+
}
38+
39+
module attributes {transform.with_named_sequence} {
40+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
41+
%0 = transform.structured.match ops{["func.func"]} in %arg1
42+
: (!transform.any_op) -> !transform.any_op
43+
transform.structured.hoist_redundant_vector_transfers %0
44+
: (!transform.any_op) -> !transform.any_op
45+
transform.yield
46+
}
47+
}
48+
49+
// -----
50+
51+
// CHECK-LABEL: func @mem_use_inside_outer_loop
52+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
53+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
54+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
55+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
56+
func.func @mem_use_inside_outer_loop(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
57+
%pad = arith.constant 0.0 : f32
58+
59+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
60+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
61+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]]{{\[}}%[[I]], %[[I]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
62+
// CHECK: %[[SCF:.*]] = scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[VAL_5:.*]] = %[[READ]]) -> (vector<1xf32>) {
63+
// CHECK: %[[USE:.*]] = "val_use"(%[[VAL_5]]) : (vector<1xf32>) -> vector<1xf32>
64+
// CHECK: scf.yield %[[USE]] : vector<1xf32>
65+
// CHECK: }
66+
// CHECK: vector.transfer_write %[[SCF]], %[[MEM]]{{\[}}%[[I]], %[[I]]] : vector<1xf32>, memref<?x?xf32>
67+
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
68+
// CHECK: }
69+
scf.for %i = %lb to %ub step %step {
70+
scf.for %j = %lb to %ub step %step {
71+
%read = vector.transfer_read %mem[%i, %i], %pad: memref<?x?xf32>, vector<1xf32>
72+
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
73+
vector.transfer_write %use, %mem[%i, %i] : vector<1xf32>, memref<?x?xf32>
74+
}
75+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
76+
}
77+
return
78+
}
79+
80+
module attributes {transform.with_named_sequence} {
81+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
82+
%0 = transform.structured.match ops{["func.func"]} in %arg1
83+
: (!transform.any_op) -> !transform.any_op
84+
transform.structured.hoist_redundant_vector_transfers %0
85+
: (!transform.any_op) -> !transform.any_op
86+
transform.yield
87+
}
88+
}
89+
90+
// -----
91+
92+
///----------------------------------------------------------------------------------------
93+
/// Tests for vector.transfer_read + vector.transfer_write pairs
94+
///
95+
/// * Nested in double loops
96+
/// * Indices are constant
97+
///----------------------------------------------------------------------------------------
98+
99+
// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_write
100+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
101+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
102+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
103+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
104+
func.func @negative_mem_use_inside_inner_loop_before_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
19105
%c0 = arith.constant 0 : index
20-
%cst = arith.constant 0.0 : f32
106+
%pad = arith.constant 0.0 : f32
107+
108+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
109+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
110+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
111+
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
112+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
113+
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
114+
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
115+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
116+
// CHECK: }
117+
// CHECK: }
118+
scf.for %i = %lb to %ub step %step {
119+
scf.for %j = %lb to %ub step %step {
120+
%read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
121+
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
122+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
123+
vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
124+
}
125+
}
126+
return
127+
}
21128

22-
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
23-
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
24-
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
25-
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
26-
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
27-
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
28-
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
29-
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<5xf32>
30-
// CHECK: "some_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
31-
// CHECK: "some_use"(%{{.*}}) : (vector<2xf32>) -> vector<2xf32>
32-
// CHECK: "some_use"(%[[MEMREF2]], %{{.*}}) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
33-
// CHECK: "some_use"(%{{.*}}) : (vector<4xf32>) -> vector<4xf32>
34-
// CHECK: "some_use"(%{{.*}}) : (vector<5xf32>) -> vector<5xf32>
35-
// CHECK: vector.transfer_write %{{.*}} : vector<3xf32>, memref<?x?xf32>
36-
// CHECK: vector.transfer_write %{{.*}} : vector<4xf32>, memref<?x?xf32>
37-
// CHECK: vector.transfer_write %{{.*}} : vector<5xf32>, memref<?x?xf32>
38-
// CHECK: "some_crippling_use"(%[[MEMREF3]]) : (memref<?x?xf32>) -> ()
39-
// CHECK: scf.yield {{.*}} : vector<1xf32>, vector<2xf32>
129+
module attributes {transform.with_named_sequence} {
130+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
131+
%0 = transform.structured.match ops{["func.func"]} in %arg1
132+
: (!transform.any_op) -> !transform.any_op
133+
transform.structured.hoist_redundant_vector_transfers %0
134+
: (!transform.any_op) -> !transform.any_op
135+
transform.yield
136+
}
137+
}
138+
139+
// -----
140+
141+
// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_after_write
142+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
143+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
144+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
145+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
146+
func.func @negative_mem_use_inside_inner_loop_after_write(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
147+
%c0 = arith.constant 0 : index
148+
%pad = arith.constant 0.0 : f32
149+
150+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
151+
// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
152+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
153+
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
154+
// CHECK: %[[READ:.*]] = vector.transfer_read %[[MEM]][%[[C0]], %[[C0]]], %[[PAD]] : memref<?x?xf32>, vector<1xf32>
155+
// CHECK: %[[USE:.*]] = "val_use"(%[[READ]]) : (vector<1xf32>) -> vector<1xf32>
156+
// CHECK: vector.transfer_write %[[USE]], %[[MEM]][%[[C0]], %[[C0]]] : vector<1xf32>, memref<?x?xf32>
157+
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
158+
// CHECK: }
159+
// CHECK: }
160+
scf.for %i = %lb to %ub step %step {
161+
scf.for %j = %lb to %ub step %step {
162+
%r3 = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
163+
%u3 = "val_use"(%r3) : (vector<1xf32>) -> vector<1xf32>
164+
vector.transfer_write %u3, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
165+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
166+
}
167+
}
168+
return
169+
}
170+
171+
module attributes {transform.with_named_sequence} {
172+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
173+
%0 = transform.structured.match ops{["func.func"]} in %arg1
174+
: (!transform.any_op) -> !transform.any_op
175+
transform.structured.hoist_redundant_vector_transfers %0
176+
: (!transform.any_op) -> !transform.any_op
177+
transform.yield
178+
}
179+
}
180+
181+
// -----
182+
183+
// CHECK-LABEL: func @negative_mem_use_inside_inner_loop_before_read
184+
// CHECK-SAME: %[[MEM:[a-zA-Z0-9]+]]: memref<?x?xf32>,
185+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]+]]: index,
186+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]+]]: index,
187+
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]+]]: index)
188+
func.func @negative_mem_use_inside_inner_loop_before_read(%mem: memref<?x?xf32>, %lb : index, %ub : index, %step: index) {
189+
%c0 = arith.constant 0 : index
190+
%pad = arith.constant 0.0 : f32
191+
192+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
193+
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
194+
// CHECK: "mem_use"(%[[MEM]]) : (memref<?x?xf32>) -> ()
195+
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
196+
// CHECK: "val_use"(%{{.*}}) : (vector<1xf32>) -> vector<1xf32>
197+
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
40198
// CHECK: }
41-
// CHECK: vector.transfer_write %{{.*}} : vector<2xf32>, memref<?x?xf32>
42-
// CHECK: "unrelated_use"(%[[MEMREF0]]) : (memref<?x?xf32>) -> ()
43-
// CHECK: scf.yield {{.*}} : vector<1xf32>
44199
// CHECK: }
45-
// CHECK: vector.transfer_write %{{.*}} : vector<1xf32>, memref<?x?xf32>
46-
// CHECK: "unrelated_use"(%[[MEMREF1]]) : (memref<?x?xf32>) -> ()
47200
scf.for %i = %lb to %ub step %step {
48201
scf.for %j = %lb to %ub step %step {
49-
%r0 = vector.transfer_read %memref1[%c0, %c0], %cst: memref<?x?xf32>, vector<1xf32>
50-
%r1 = vector.transfer_read %memref0[%i, %i], %cst: memref<?x?xf32>, vector<2xf32>
51-
%r2 = vector.transfer_read %memref2[%c0, %c0], %cst: memref<?x?xf32>, vector<3xf32>
52-
%r3 = vector.transfer_read %memref3[%c0, %c0], %cst: memref<?x?xf32>, vector<4xf32>
53-
"some_crippling_use"(%memref4) : (memref<?x?xf32>) -> ()
54-
%r4 = vector.transfer_read %memref4[%c0, %c0], %cst: memref<?x?xf32>, vector<5xf32>
55-
%r5 = vector.transfer_read %memref5[%c0, %c0], %cst: memref<?x?xf32>, vector<6xf32>
56-
"some_crippling_use"(%memref5) : (memref<?x?xf32>) -> ()
57-
%u0 = "some_use"(%r0) : (vector<1xf32>) -> vector<1xf32>
58-
%u1 = "some_use"(%r1) : (vector<2xf32>) -> vector<2xf32>
59-
%u2 = "some_use"(%memref2, %r2) : (memref<?x?xf32>, vector<3xf32>) -> vector<3xf32>
60-
%u3 = "some_use"(%r3) : (vector<4xf32>) -> vector<4xf32>
61-
%u4 = "some_use"(%r4) : (vector<5xf32>) -> vector<5xf32>
62-
%u5 = "some_use"(%r5) : (vector<6xf32>) -> vector<6xf32>
63-
vector.transfer_write %u0, %memref1[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
64-
vector.transfer_write %u1, %memref0[%i, %i] : vector<2xf32>, memref<?x?xf32>
65-
vector.transfer_write %u2, %memref2[%c0, %c0] : vector<3xf32>, memref<?x?xf32>
66-
vector.transfer_write %u3, %memref3[%c0, %c0] : vector<4xf32>, memref<?x?xf32>
67-
vector.transfer_write %u4, %memref4[%c0, %c0] : vector<5xf32>, memref<?x?xf32>
68-
vector.transfer_write %u5, %memref5[%c0, %c0] : vector<6xf32>, memref<?x?xf32>
69-
"some_crippling_use"(%memref3) : (memref<?x?xf32>) -> ()
202+
"mem_use"(%mem) : (memref<?x?xf32>) -> ()
203+
%read = vector.transfer_read %mem[%c0, %c0], %pad: memref<?x?xf32>, vector<1xf32>
204+
%use = "val_use"(%read) : (vector<1xf32>) -> vector<1xf32>
205+
vector.transfer_write %use, %mem[%c0, %c0] : vector<1xf32>, memref<?x?xf32>
70206
}
71-
"unrelated_use"(%memref0) : (memref<?x?xf32>) -> ()
72207
}
73-
"unrelated_use"(%memref1) : (memref<?x?xf32>) -> ()
74208
return
75209
}
76210

@@ -86,6 +220,12 @@ module attributes {transform.with_named_sequence} {
86220

87221
// -----
88222

223+
///----------------------------------------------------------------------------------------
224+
/// Other tests
225+
///
226+
/// TODO: Document
227+
///----------------------------------------------------------------------------------------
228+
89229
// CHECK-LABEL: func @hoist_vector_transfer_pairs_disjoint(
90230
// CHECK-SAME: %[[MEMREF0:[a-zA-Z0-9]*]]: memref<?x?xf32>,
91231
// CHECK-SAME: %[[MEMREF1:[a-zA-Z0-9]*]]: memref<?x?xf32>,

0 commit comments

Comments
 (0)