Skip to content

Commit 9e1f7a1

Browse files
committed
add option for trip count verification
Signed-off-by: Max Dawkins <[email protected]>
1 parent 5c0aae2 commit 9e1f7a1

File tree

5 files changed

+72
-76
lines changed

5 files changed

+72
-76
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,13 +2294,15 @@ def HoistRedundantVectorTransfersOp :
22942294
function op.
22952295
}];
22962296

2297-
let arguments = (ins TransformHandleTypeInterface:$target);
2297+
let arguments = (ins TransformHandleTypeInterface:$target,
2298+
UnitAttr:$verify_non_zero_trip);
22982299
let results = (outs TransformHandleTypeInterface:$transformed);
22992300

23002301
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results) ";
23012302

23022303
let builders = [
2303-
OpBuilder<(ins "Value":$target)>,
2304+
OpBuilder<(ins "Value":$target,
2305+
CArg<"bool", "false">:$verify_non_zero_trip)>,
23042306
];
23052307
let extraClassDeclaration = [{
23062308
::mlir::DiagnosedSilenceableFailure applyToOne(

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ namespace linalg {
2929
/// 4. The source operands for vector.transfer_{read|write} do not originate
3030
/// from Ops implementing ViewLikeOpInterface (to reduce the risk of
3131
/// aliasing).
32+
/// 5. If `verifyNonZeroTrip` is true, then the lower bound of the loop must
33+
/// be statically smaller than the upper bound of the loop, guaranteeing that
34+
/// the loop body will execute at least once.
3235
/// To improve hoisting opportunities, call the `moveLoopInvariantCode` helper
3336
/// function on the candidate loop above which to hoist. Hoisting the transfers
3437
/// results in scf::ForOp yielding the value that originally transited through
@@ -41,7 +44,8 @@ namespace linalg {
4144
///
4245
/// WARNING: This hoisting does not model parallelism and is generally incorrect
4346
/// when used on distributed loops with memref semantics!
44-
void hoistRedundantVectorTransfers(Operation *root);
47+
void hoistRedundantVectorTransfers(Operation *root,
48+
bool verifyNonZeroTrip = false);
4549

4650
/// Hoist vector.extract/vector.broadcast pairs out of immediately enclosing
4751
/// scf::ForOp iteratively, if the following conditions are met:

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3558,7 +3558,7 @@ transform::HoistRedundantVectorTransfersOp::applyToOne(
35583558
// WARNING: This hoisting does not model parallelism and is generally
35593559
// incorrect when used on distributed loops with memref semantics!
35603560
// TODO: obsolete and should be retired.
3561-
linalg::hoistRedundantVectorTransfers(target);
3561+
linalg::hoistRedundantVectorTransfers(target, getVerifyNonZeroTrip());
35623562
results.push_back(target);
35633563
return DiagnosedSilenceableFailure::success();
35643564
}

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
199199
return true;
200200
}
201201

202-
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
202+
void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
203+
bool verifyNonZeroTrip) {
203204
bool changed = true;
204205
while (changed) {
205206
changed = false;
@@ -213,39 +214,41 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
213214
// a potentially zero trip count loop may cause a vector transfer to be
214215
// executed when it shouldn't be.
215216
llvm::DenseSet<LoopLikeOpInterface> definiteNonZeroTripCountLoops;
216-
root->walk([&](LoopLikeOpInterface loopLike) {
217-
std::optional<SmallVector<OpFoldResult>> lbs =
218-
loopLike.getLoopLowerBounds();
219-
std::optional<SmallVector<OpFoldResult>> ubs =
220-
loopLike.getLoopUpperBounds();
221-
// If loop bounds cannot be found, assume possibly zero trip count.
222-
if (!lbs || !ubs) {
223-
return;
224-
}
225-
// Otherwise, use ValueBounds to find the maximum lower bound and
226-
// minimum upper bound. If the bounds are found, and maxLb is less
227-
// than the minUb, then the loop will not have zero trip count.
228-
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
229-
FailureOr<int64_t> maxLb =
230-
ValueBoundsConstraintSet::computeConstantBound(
231-
presburger::BoundType::UB, /*var=*/lb,
232-
/*stopCondition=*/nullptr, /*closedUB=*/true);
233-
if (failed(maxLb)) {
234-
return;
235-
}
236-
FailureOr<int64_t> minUb =
237-
ValueBoundsConstraintSet::computeConstantBound(
238-
presburger::BoundType::LB, /*var=*/ub,
239-
/*stopCondition=*/nullptr);
240-
if (failed(minUb)) {
217+
if (verifyNonZeroTrip) {
218+
root->walk([&](LoopLikeOpInterface loopLike) {
219+
std::optional<SmallVector<OpFoldResult>> lbs =
220+
loopLike.getLoopLowerBounds();
221+
std::optional<SmallVector<OpFoldResult>> ubs =
222+
loopLike.getLoopUpperBounds();
223+
// If loop bounds cannot be found, assume possibly zero trip count.
224+
if (!lbs || !ubs) {
241225
return;
242226
}
243-
if (minUb.value() <= maxLb.value()) {
244-
return;
227+
// Otherwise, use ValueBounds to find the maximum lower bound and
228+
// minimum upper bound. If the bounds are found, and maxLb is less
229+
// than the minUb, then the loop will not have zero trip count.
230+
for (auto [lb, ub] : llvm::zip_equal(lbs.value(), ubs.value())) {
231+
FailureOr<int64_t> maxLb =
232+
ValueBoundsConstraintSet::computeConstantBound(
233+
presburger::BoundType::UB, /*var=*/lb,
234+
/*stopCondition=*/nullptr, /*closedUB=*/true);
235+
if (failed(maxLb)) {
236+
return;
237+
}
238+
FailureOr<int64_t> minUb =
239+
ValueBoundsConstraintSet::computeConstantBound(
240+
presburger::BoundType::LB, /*var=*/ub,
241+
/*stopCondition=*/nullptr);
242+
if (failed(minUb)) {
243+
return;
244+
}
245+
if (minUb.value() <= maxLb.value()) {
246+
return;
247+
}
248+
definiteNonZeroTripCountLoops.insert(loopLike);
245249
}
246-
definiteNonZeroTripCountLoops.insert(loopLike);
247-
}
248-
});
250+
});
251+
}
249252

250253
root->walk([&](vector::TransferReadOp transferRead) {
251254
if (!isa<MemRefType>(transferRead.getShapedType()))
@@ -259,7 +262,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root) {
259262
if (!isa_and_nonnull<scf::ForOp, affine::AffineForOp>(loop))
260263
return WalkResult::advance();
261264

262-
if (!definiteNonZeroTripCountLoops.contains(loop)) {
265+
if (verifyNonZeroTrip && !definiteNonZeroTripCountLoops.contains(loop)) {
263266
LLVM_DEBUG(DBGS() << "Loop may have zero trip count: " << *loop
264267
<< "\n");
265268
return WalkResult::advance();

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,21 @@
88
// CHECK-SAME: %[[MEMREF4:[a-zA-Z0-9]*]]: memref<?x?xf32>,
99
// CHECK-SAME: %[[MEMREF5:[a-zA-Z0-9]*]]: memref<?x?xf32>,
1010
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
11+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
12+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
1113
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
1214
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
1315
func.func @hoist_vector_transfer_pairs(
1416
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>, %memref2: memref<?x?xf32>,
1517
%memref3: memref<?x?xf32>, %memref4: memref<?x?xf32>, %memref5: memref<?x?xf32>,
16-
%val: index, %step: index, %cmp: i1) {
17-
%lb = arith.constant 0 : index
18-
%ub = arith.constant 16 : index
18+
%val: index, %lb : index, %ub : index, %step: index, %cmp: i1) {
1919
%c0 = arith.constant 0 : index
2020
%cst = arith.constant 0.0 : f32
2121

2222
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<1xf32>
23-
// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
23+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>) {
2424
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<2xf32>
25-
// CHECK: scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
25+
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) -> (vector<1xf32>, vector<2xf32>) {
2626
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<3xf32>
2727
// CHECK: vector.transfer_read %{{.*}} : memref<?x?xf32>, vector<4xf32>
2828
// CHECK: "some_crippling_use"(%[[MEMREF4]]) : (memref<?x?xf32>) -> ()
@@ -92,15 +92,15 @@ module attributes {transform.with_named_sequence} {
9292
// CHECK-SAME: %[[MEMREF2:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9393
// CHECK-SAME: %[[MEMREF3:[a-zA-Z0-9]*]]: memref<?x?xf32>,
9494
// CHECK-SAME: %[[VAL:[a-zA-Z0-9]*]]: index,
95+
// CHECK-SAME: %[[LB:[a-zA-Z0-9]*]]: index,
96+
// CHECK-SAME: %[[UB:[a-zA-Z0-9]*]]: index,
9597
// CHECK-SAME: %[[STEP:[a-zA-Z0-9]*]]: index,
9698
// CHECK-SAME: %[[RANDOM:[a-zA-Z0-9]*]]: index,
9799
// CHECK-SAME: %[[CMP:[a-zA-Z0-9]*]]: i1
98100
func.func @hoist_vector_transfer_pairs_disjoint(
99101
%memref0: memref<?x?xf32>, %memref1: memref<?x?xf32>,
100-
%memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index,
102+
%memref2: memref<?x?xf32>, %memref3: memref<?x?xf32>, %val: index, %lb : index, %ub : index,
101103
%step: index, %random_index : index, %cmp: i1) {
102-
%lb = arith.constant 0 : index
103-
%ub = arith.constant 16 : index
104104
%c0 = arith.constant 0 : index
105105
%c1 = arith.constant 1 : index
106106
%c3 = arith.constant 3 : index
@@ -110,9 +110,9 @@ func.func @hoist_vector_transfer_pairs_disjoint(
110110
// CHECK: vector.transfer_read %[[MEMREF2]]{{.*}} : memref<?x?xf32>, vector<3xf32>
111111
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
112112
// CHECK: vector.transfer_read %[[MEMREF3]]{{.*}} : memref<?x?xf32>, vector<4xf32>
113-
// CHECK: scf.for %[[I:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
113+
// CHECK: scf.for %[[I:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
114114
// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
115-
// CHECK: scf.for %[[J:.*]] = {{.*}} to {{.*}} step %[[STEP]] iter_args({{.*}}) ->
115+
// CHECK: scf.for %[[J:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args({{.*}}) ->
116116
// CHECK-SAME: (vector<3xf32>, vector<3xf32>, vector<4xf32>, vector<4xf32>) {
117117
// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
118118
// CHECK: vector.transfer_read %[[MEMREF1]]{{.*}} : memref<?x?xf32>, vector<2xf32>
@@ -309,18 +309,18 @@ module attributes {transform.with_named_sequence} {
309309
// -----
310310

311311
// CHECK-LABEL: func.func @no_hoisting_zero_trip_loop
312-
func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi32>, %lb: index, %ub: index) {
312+
func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %lb: index, %ub: index) {
313313
%c0_i32 = arith.constant 0 : i32
314314
%c0 = arith.constant 0 : index
315315
%c1 = arith.constant 1 : index
316316
// %lb and %ub are unbounded, so do not hoist.
317317

318318
// CHECK: scf.for {{.*}} {
319319
// CHECK-NEXT: vector.transfer_read
320-
// CHECK-NEXT: vector.transfer_write
320+
// CHECK-NEXT: "prevent.dce"
321321
scf.for %arg2 = %lb to %ub step %c1 {
322322
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
323-
vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
323+
"prevent.dce"(%read) : (vector<4xi32>) ->()
324324
}
325325

326326
// %lb_0 is in range [%lb, 8], and %ub_0 is in range [4, %ub].
@@ -330,24 +330,23 @@ func.func @no_hoisting_zero_trip_loop(%arg0: memref<20xi32>, %arg1: memref<20xi3
330330

331331
// CHECK: scf.for {{.*}} {
332332
// CHECK-NEXT: vector.transfer_read
333-
// CHECK-NEXT: vector.transfer_write
333+
// CHECK-NEXT: "prevent.dce"
334334
scf.for %arg2 = %lb_0 to %ub_0 step %c1 {
335335
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
336-
vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
336+
"prevent.dce"(%read) : (vector<4xi32>) ->()
337337
}
338338

339339
// %lb_1 = min(%lb, 4) is at most 4, and %ub_1 = max(%ub, 8) is at least 8.
340340
// Since %lb_1 is guaranteed to be less than %ub_1, hoisting is possible.
341341
%lb_1 = affine.min affine_map<(d0) -> (d0, 4)>(%lb)
342342
%ub_1 = affine.max affine_map<(d0) -> (d0, 8)>(%ub)
343343

344-
// CHECK: vector.transfer_read
344+
// CHECK: vector.transfer_read
345345
// CHECK: scf.for {{.*}} {
346346
// CHECK-NEXT: "prevent.dce"
347347
scf.for %arg2 = %lb_1 to %ub_1 step %c1 {
348348
%read = vector.transfer_read %arg0[%c0], %c0_i32 {in_bounds = [true]} : memref<20xi32>, vector<4xi32>
349349
"prevent.dce"(%read) : (vector<4xi32>) ->()
350-
vector.transfer_write %read, %arg1[%c0] {in_bounds = [true]} : vector<4xi32>, memref<20xi32>
351350
}
352351
return
353352
}
@@ -356,7 +355,7 @@ module attributes {transform.with_named_sequence} {
356355
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
357356
%0 = transform.structured.match ops{["func.func"]} in %arg1
358357
: (!transform.any_op) -> !transform.any_op
359-
transform.structured.hoist_redundant_vector_transfers %0
358+
transform.structured.hoist_redundant_vector_transfers %0 { verify_non_zero_trip }
360359
: (!transform.any_op) -> !transform.any_op
361360
transform.yield
362361
}
@@ -492,7 +491,7 @@ module attributes {transform.with_named_sequence} {
492491
// CHECK: #[[$MAP4:.+]] = affine_map<()[s0] -> (s0 + 4)>
493492

494493
// CHECK-LABEL: func.func @hoist_vector_transfer_pairs_disjoint_dynamic
495-
// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %[[I0:.+]]: index)
494+
// CHECK-SAME: (%[[BUFFER:.+]]: memref<?x?xf32>, %{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[I0:.+]]: index)
496495

497496
// CHECK: %[[PLUS1:.+]] = affine.apply #[[$MAP1]]()[%[[I0]]]
498497
// CHECK: %[[PLUS4:.+]] = affine.apply #[[$MAP4]]()[%[[I0]]]
@@ -507,9 +506,7 @@ module attributes {transform.with_named_sequence} {
507506
// CHECK: vector.transfer_write %{{.+}}, %[[BUFFER]][%[[I0]], %[[I0]]]
508507

509508
func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
510-
%buffer: memref<?x?xf32>, %step: index, %i0 : index) {
511-
%lb = arith.constant 0 : index
512-
%ub = arith.constant 16 : index
509+
%buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
513510
%cst = arith.constant 0.0 : f32
514511
%i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i0)
515512
%i2 = affine.apply affine_map<(d0) -> (d0 + 4)>(%i0)
@@ -552,9 +549,7 @@ module attributes {transform.with_named_sequence} {
552549
// CHECK-COUNT-2: vector.transfer_write
553550

554551
func.func @hoist_vector_transfer_pairs_overlapping_dynamic(
555-
%buffer: memref<?x?xf32>, %step: index, %i0 : index) {
556-
%lb = arith.constant 0 : index
557-
%ub = arith.constant 16 : index
552+
%buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index) {
558553
%cst = arith.constant 0.0 : f32
559554
%i1 = affine.apply affine_map<(d0) -> (d0 + 3)>(%i0)
560555

@@ -594,9 +589,7 @@ module attributes {transform.with_named_sequence} {
594589
// CHECK: return
595590

596591
func.func @hoist_vector_transfer_pairs_disjoint_dynamic(
597-
%buffer: memref<?x?xf32>, %step: index, %i0 : index, %i1 : index) {
598-
%lb = arith.constant 0 : index
599-
%ub = arith.constant 16 : index
592+
%buffer: memref<?x?xf32>, %lb : index, %ub : index, %step: index, %i0 : index, %i1 : index) {
600593
%cst = arith.constant 0.0 : f32
601594
%i2 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16)>(%i1)
602595
%i3 = affine.apply affine_map<(d0) -> ((d0 floordiv 32) * 16 + 8)>(%i1)
@@ -633,7 +626,7 @@ module attributes {transform.with_named_sequence} {
633626
// Test hoisting of vector.extract/vector.broadcast pairs
634627

635628
// CHECK-LABEL: func.func @hoist_vector_broadcasts
636-
// CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
629+
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
637630
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
638631
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
639632
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -642,9 +635,7 @@ module attributes {transform.with_named_sequence} {
642635
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
643636
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
644637

645-
func.func @hoist_vector_broadcasts(%step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
646-
%lb = arith.constant 0 : index
647-
%ub = arith.constant 16 : index
638+
func.func @hoist_vector_broadcasts(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>) -> vector<3x4xf32> {
648639
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
649640
%extract = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
650641
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -669,7 +660,7 @@ module attributes {transform.with_named_sequence} {
669660
// Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
670661

671662
// CHECK-LABEL: func.func @hoist_vector_broadcasts
672-
// CHECK-SAME: (%{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
663+
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
673664
// CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
674665
// CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
675666
// CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
@@ -678,9 +669,7 @@ module attributes {transform.with_named_sequence} {
678669
// CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
679670
// CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
680671

681-
func.func @hoist_vector_broadcasts_dynamic(%step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
682-
%lb = arith.constant 0 : index
683-
%ub = arith.constant 16 : index
672+
func.func @hoist_vector_broadcasts_dynamic(%lb : index, %ub : index, %step : index, %vec : vector<3x4xf32>, %pos: index) -> vector<3x4xf32> {
684673
%bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec) -> vector<3x4xf32> {
685674
%extract = vector.extract %iarg[%pos] : vector<4xf32> from vector<3x4xf32>
686675
%use = "some_use"(%extract) : (vector<4xf32>) -> vector<4xf32>
@@ -705,7 +694,7 @@ module attributes {transform.with_named_sequence} {
705694
// Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
706695

707696
// CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
708-
// CHECK-SAME: (%{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
697+
// CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
709698
// CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
710699
// CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
711700
// CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
@@ -718,9 +707,7 @@ module attributes {transform.with_named_sequence} {
718707
// CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
719708
// CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
720709

721-
func.func @hoist_vector_broadcasts_multiple(%step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
722-
%lb = arith.constant 0 : index
723-
%ub = arith.constant 16 : index
710+
func.func @hoist_vector_broadcasts_multiple(%lb : index, %ub : index, %step : index, %vec1 : vector<3x4xf32>, %vec2 : vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
724711
%bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args(%iarg = %vec1, %iarg2 = %vec2) -> (vector<3x4xf32>, vector<3x5xf32>) {
725712
%extract1 = vector.extract %iarg[0] : vector<4xf32> from vector<3x4xf32>
726713
%extract2 = vector.extract %iarg2[1] : vector<5xf32> from vector<3x5xf32>

0 commit comments

Comments
 (0)