Skip to content

Commit 0ea1156

Browse files
committed
address comments
Signed-off-by: Max Dawkins <[email protected]>
1 parent 9e1f7a1 commit 0ea1156

File tree

2 files changed

+62
-23
lines changed

2 files changed

+62
-23
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,24 +221,23 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
221221
std::optional<SmallVector<OpFoldResult>> ubs =
222222
loopLike.getLoopUpperBounds();
223223
// If loop bounds cannot be found, assume possibly zero trip count.
224-
if (!lbs || !ubs) {
224+
if (!lbs || !ubs)
225225
return;
226-
}
226+
227227
// Otherwise, use ValueBounds to find the maximum lower bound and
228228
// minimum upper bound. If the bounds are found, and maxLb is less
229229
// than the minUb, then the loop will not have zero trip count.
230230
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231231
FailureOr<int64_t> maxLb =
232232
ValueBoundsConstraintSet::computeConstantBound(
233-
presburger::BoundType::UB, /*var=*/lb,
233+
presburger::BoundType::UB, lb,
234234
/*stopCondition=*/nullptr, /*closedUB=*/true);
235235
if (failed(maxLb)) {
236236
return;
237237
}
238238
FailureOr<int64_t> minUb =
239239
ValueBoundsConstraintSet::computeConstantBound(
240-
presburger::BoundType::LB, /*var=*/ub,
241-
/*stopCondition=*/nullptr);
240+
presburger::BoundType::LB, ub);
242241
if (failed(minUb)) {
243242
return;
244243
}

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -308,20 +308,40 @@ module attributes {transform.with_named_sequence} {
308308

309309
// -----
310310

311-
// CHECK-LABEL: func.func @no_hoisting_zero_trip_loop
312-
func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %lb: index, %ub: index) {
311+
// CHECK-LABEL: func.func @no_hoisting_unknown_bound_loop
312+
func.func @no_hoisting_unknown_bound_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
313313
%c0_i32 = arith.constant 0 : i32
314314
%c0 = arith.constant 0 : index
315315
%c1 = arith.constant 1 : index
316-
// %lb and %ub are unbounded, so do not hoist.
317316

317+
// %lb and %ub are unbounded, so do not hoist.
318318
// CHECK: scf.for {{.*}} {
319319
// CHECK-NEXT: vector.transfer_read
320-
// CHECK-NEXT: "prevent.dce"
320+
// CHECK-NEXT: "test.some_use"
321321
scf.for %arg2 = %lb to %ub step %c1 {
322-
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
323-
"prevent.dce"(%read) : (vector<4xi32>) ->()
322+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
323+
"test.some_use"(%read) : (vector<4xi32>) ->()
324+
}
325+
return
326+
}
327+
328+
module attributes {transform.with_named_sequence} {
329+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
330+
%0 = transform.structured.match ops{["func.func"]} in %arg1
331+
: (!transform.any_op) -> !transform.any_op
332+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
333+
: (!transform.any_op) -> !transform.any_op
334+
transform.yield
324335
}
336+
}
337+
338+
// -----
339+
340+
// CHECK-LABEL: func.func @no_hoisting_possibly_zero_trip_loop
341+
func.func @no_hoisting_possibly_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
342+
%c0_i32 = arith.constant 0 : i32
343+
%c0 = arith.constant 0 : index
344+
%c1 = arith.constant 1 : index
325345

326346
// %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
327347
// Since %lb_0 could be greater than %ub_0, do not hoist.
@@ -330,23 +350,43 @@ func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %lb: index, %ub: in
330350

331351
// CHECK: scf.for {{.*}} {
332352
// CHECK-NEXT: vector.transfer_read
333-
// CHECK-NEXT: "prevent.dce"
353+
// CHECK-NEXT: "test.some_use"
334354
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
335-
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
336-
"prevent.dce"(%read) : (vector<4xi32>) ->()
355+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
356+
"test.some_use"(%read) : (vector<4xi32>) ->()
337357
}
358+
return
359+
}
338360

339-
// %lb_1 is in range [%lb, 4], and %ub_1 is in range [8, %ub].
340-
// Since %lb_1 is guaranteed to be less than %ub_1, hoisting is possible.
341-
%lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
342-
%ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
361+
module attributes {transform.with_named_sequence} {
362+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
363+
%0 = transform.structured.match ops{["func.func"]} in %arg1
364+
: (!transform.any_op) -> !transform.any_op
365+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
366+
: (!transform.any_op) -> !transform.any_op
367+
transform.yield
368+
}
369+
}
370+
371+
// -----
372+
373+
// CHECK-LABEL: func.func @hoisting_non_zero_trip_loop
374+
func.func @hoisting_non_zero_trip_loop(%memref0: memref<20xi32>, %lb: index, %ub: index) {
375+
%c0_i32 = arith.constant 0 : i32
376+
%c0 = arith.constant 0 : index
377+
%c1 = arith.constant 1 : index
378+
379+
// %lb_0 is in range [%lb, 4], and %ub_0 is in range [8, %ub].
380+
// Since %lb_0 is guaranteed to be less than %ub_0, hoisting is possible.
381+
%lb_0 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
382+
%ub_0 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
343383

344384
// CHECK: vector.transfer_read
345385
// CHECK: scf.for {{.*}} {
346-
// CHECK-NEXT: "prevent.dce"
347-
scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
348-
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
349-
"prevent.dce"(%read) : (vector<4xi32>) ->()
386+
// CHECK-NEXT: "test.some_use"
387+
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
388+
%read = vector.transfer_read %memref0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
389+
"test.some_use"(%read) : (vector<4xi32>) ->()
350390
}
351391
return
352392
}
@@ -421,7 +461,7 @@ func.func @no_hoisting_collapse_shape_2(%vec: vector<1x12x1xi32>) {
421461
%collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x12x1xi32> into memref<12xi32>
422462
vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x12x1xi32>, memref<1x12x1xi32>
423463
%read = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<12xi32>, vector<12xi32>
424-
"prevent.dce"(%read) : (vector<12xi32>) ->()
464+
"test.some_use"(%read) : (vector<12xi32>) ->()
425465
}
426466
return
427467
}

0 commit comments

Comments
 (0)