Skip to content

Commit 86d0c2c

Browse files
committed
[mlir] Don't hoist transfers from potentially zero trip loops
Signed-off-by: Max Dawkins <[email protected]>
1 parent 87645e9 commit 86d0c2c

File tree

2 files changed

+133
-20
lines changed

2 files changed

+133
-20
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,45 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
208208
root->walk(
209209
[&](LoopLikeOpInterface loopLike) { moveLoopInvariantCode(loopLike); });
210210

211+
// Find all loops that are certain to have a non-zero trip count. Any loops
212+
// that are not part of this set cannot be hoisted from, since hoisting from
213+
// a potentially zero trip count loop may cause a vector transfer to be
214+
// executed when it shouldn't be.
215+
llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
216+
root->walk([&](LoopLikeOpInterface loopLike) {
217+
std::optional<SmallVector<OpFoldResult>> lbs =
218+
loopLike.getLoopLowerBounds();
219+
std::optional<SmallVector<OpFoldResult>> ubs =
220+
loopLike.getLoopUpperBounds();
221+
// If loop bounds cannot be found, assume possibly zero trip count.
222+
if (!lbs || !ubs) {
223+
return;
224+
}
225+
// Otherwise, use ValueBounds to find the maximum lower bound and
226+
// minimum upper bound. If the bounds are found, and maxLb is less
227+
// than the minUb, then the loop will not have zero trip count.
228+
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
229+
FailureOr<int64_t> maxLb =
230+
ValueBoundsConstraintSet::computeConstantBound(
231+
presburger::BoundType::UB, /*var=*/lb,
232+
/*stopCondition=*/nullptr, /*closedUB=*/true);
233+
if (failed(maxLb)) {
234+
return;
235+
}
236+
FailureOr<int64_t> minUb =
237+
ValueBoundsConstraintSet::computeConstantBound(
238+
presburger::BoundType::LB, /*var=*/ub,
239+
/*stopCondition=*/nullptr);
240+
if (failed(minUb)) {
241+
return;
242+
}
243+
if (minUb.value() <= maxLb.value()) {
244+
return;
245+
}
246+
definiteNonZeroTripCountLoops.insert(loopLike);
247+
}
248+
});
249+
211250
root->walk([&](vector::TransferReadOp transferRead) {
212251
if (!isa<MemRefType>(transferRead.getShapedType()))
213252
return WalkResult::advance();
@@ -220,6 +259,12 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
220259
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
221260
return WalkResult::advance();
222261

262+
if (!definiteNonZeroTripCountLoops.contains(loop)) {
263+
LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
264+
<< "\n");
265+
return WalkResult::advance();
266+
}
267+
223268
LLVM_DEBUG(DBGS() << "Candidate read: " << *transferRead.getOperation()
224269
<< "\n");
225270

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 88 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
99
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
1010
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
11-
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
12-
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
1311
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
1412
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
1513
func.func @hoist_vector_transfer_pairs(
1614
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
1715
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
18-
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
16+
%val: index, %step: index, %cmp: i1) {
17+
%lb = arith.constant 0 : index
18+
%ub = arith.constant 16 : index
1919
%c0 = arith.constant 0 : index
2020
%cst = arith.constant 0.0 : f32
2121

2222
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
23-
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
23+
// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
2424
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
25-
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
25+
// CHECK: scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
2626
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
2727
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
2828
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
@@ -92,15 +92,15 @@ module attributes {transform.with_named_sequence} {
9292
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9393
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9494
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
95-
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
96-
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
9795
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
9896
// CHECK-SAME: %[[RANDOM:[a-zA-Z0-9]*]]: index,
9997
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
10098
func.func @hoist_vector_transfer_pairs_disjoint(
10199
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
102-
%memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
100+
%memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index,
103101
%step: index, %random_index : index, %cmp: i1) {
102+
%lb = arith.constant 0 : index
103+
%ub = arith.constant 16 : index
104104
%c0 = arith.constant 0 : index
105105
%c1 = arith.constant 1 : index
106106
%c3 = arith.constant 3 : index
@@ -110,9 +110,9 @@ func.func @hoist_vector_transfer_pairs_disjoint(
110110
// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
111111
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
112112
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
113-
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
113+
// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
114114
// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
115-
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
115+
// CHECK: scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
116116
// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
117117
// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
118118
// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
@@ -308,6 +308,62 @@ module attributes {transform.with_named_sequence} {
308308

309309
// -----
310310

311+
// CHECK-LABEL: func.func @no_hoisting_zero_trip_loop
312+
func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi32>, %lb: index, %ub: index) {
313+
%c0_i32 = arith.constant 0 : i32
314+
%c0 = arith.constant 0 : index
315+
%c1 = arith.constant 1 : index
316+
// %lb and %ub are unbounded, so do not hoist.
317+
318+
// CHECK: scf.for {{.*}} {
319+
// CHECK-NEXT: vector.transfer_read
320+
// CHECK-NEXT: vector.transfer_write
321+
scf.for %arg2 = %lb to %ub step %c1 {
322+
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
323+
vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
324+
}
325+
326+
// %lb_0 = min(%lb, 8) is at most 8, and %ub_0 = max(%ub, 4) is at least 4.
327+
// Since %lb_0 could be greater than %ub_0, do not hoist.
328+
%lb_0 = affine.min affine_map<(d0) -> (d0, 8)>(%lb)
329+
%ub_0 = affine.max affine_map<(d0) -> (d0, 4)>(%ub)
330+
331+
// CHECK: scf.for {{.*}} {
332+
// CHECK-NEXT: vector.transfer_read
333+
// CHECK-NEXT: vector.transfer_write
334+
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
335+
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
336+
vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
337+
}
338+
339+
// %lb_1 = min(%lb, 4) is at most 4, and %ub_1 = max(%ub, 8) is at least 8.
340+
// Since %lb_1 is guaranteed to be less than %ub_1, hoisting is possible.
341+
%lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
342+
%ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
343+
344+
// CHECK: vector.transfer_read
345+
// CHECK: scf.for {{.*}} {
346+
// CHECK-NEXT: "prevent.dce"
347+
scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
348+
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
349+
"prevent.dce"(%read) : (vector<4xi32>) ->()
350+
vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
351+
}
352+
return
353+
}
354+
355+
module attributes {transform.with_named_sequence} {
356+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
357+
%0 = transform.structured.match ops{["func.func"]} in %arg1
358+
: (!transform.any_op) -> !transform.any_op
359+
transform.structured.hoist_redundant_vector_transfers %0
360+
: (!transform.any_op) -> !transform.any_op
361+
transform.yield
362+
}
363+
}
364+
365+
// -----
366+
311367
// Regression test - `vector.transfer_read` below should not be hoisted.
312368
// Indeed, %collapse_shape (written to by `vector.transfer_write`) and %alloca
313369
// (read by `vector.transfer_read`) alias.
@@ -436,7 +492,7 @@ module attributes {transform.with_named_sequence} {
436492
// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
437493

438494
// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
439-
// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
495+
// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %[[I0:.+]]: index)
440496

441497
// CHECK: %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
442498
// CHECK: %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
@@ -451,7 +507,9 @@ module attributes {transform.with_named_sequence} {
451507
// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
452508

453509
func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
454-
%buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
510+
%buffer: memref<?x?xf32>, %step: index, %i0 : index) {
511+
%lb = arith.constant 0 : index
512+
%ub = arith.constant 16 : index
455513
%cst = arith.constant 0.0 : f32
456514
%i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
457515
%i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
@@ -494,7 +552,9 @@ module attributes {transform.with_named_sequence} {
494552
// CHECK-COUNT-2: vector.transfer_write
495553

496554
func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
497-
%buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
555+
%buffer: memref<?x?xf32>, %step: index, %i0 : index) {
556+
%lb = arith.constant 0 : index
557+
%ub = arith.constant 16 : index
498558
%cst = arith.constant 0.0 : f32
499559
%i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
500560

@@ -534,7 +594,9 @@ module attributes {transform.with_named_sequence} {
534594
// CHECK: return
535595

536596
func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
537-
%buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
597+
%buffer: memref<?x?xf32>, %step: index, %i0 : index, %i1 : index) {
598+
%lb = arith.constant 0 : index
599+
%ub = arith.constant 16 : index
538600
%cst = arith.constant 0.0 : f32
539601
%i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
540602
%i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
@@ -571,7 +633,7 @@ module attributes {transform.with_named_sequence} {
571633
// Test hoisting of vector.extract/vector.broadcast pairs
572634

573635
// CHECK-LABEL: func.func @hoist_vector_broadcasts
574-
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
636+
// CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
575637
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
576638
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
577639
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -580,7 +642,9 @@ module attributes {transform.with_named_sequence} {
580642
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
581643
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
582644

583-
func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
645+
func.func @hoist_vector_broadcasts(%step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
646+
%lb = arith.constant 0 : index
647+
%ub = arith.constant 16 : index
584648
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
585649
%extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
586650
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -605,7 +669,7 @@ module attributes {transform.with_named_sequence} {
605669
// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
606670

607671
// CHECK-LABEL: func.func @hoist_vector_broadcasts
608-
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
672+
// CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
609673
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
610674
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
611675
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -614,7 +678,9 @@ module attributes {transform.with_named_sequence} {
614678
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
615679
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
616680

617-
func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
681+
func.func @hoist_vector_broadcasts_dynamic(%step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
682+
%lb = arith.constant 0 : index
683+
%ub = arith.constant 16 : index
618684
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
619685
%extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
620686
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -639,7 +705,7 @@ module attributes {transform.with_named_sequence} {
639705
// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
640706

641707
// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
642-
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
708+
// CHECK-SAME: (%{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
643709
// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
644710
// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
645711
// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
@@ -652,7 +718,9 @@ module attributes {transform.with_named_sequence} {
652718
// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
653719
// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
654720

655-
func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
721+
func.func @hoist_vector_broadcasts_multiple(%step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
722+
%lb = arith.constant 0 : index
723+
%ub = arith.constant 16 : index
656724
%bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
657725
%extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
658726
%extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>

0 commit comments

Comments
 (0)