Skip to content

Commit 7ea1c39

Browse files
[mlir][Transforms] LISH: Improve bypass analysis for loop-like ops (#70623)
Improve the bypass analysis of loop-invariant subset hoisting (LISH) for loop-like ops. Until now, loop-like ops were treated like any other non-subset op: they prevented hoisting of any sort because the analysis did not know which parts of a tensor init operand are accessed by the loop-like op. With this change, the analysis can look into loop-like ops and analyze which subset they are operating on.
1 parent 97c9c16 commit 7ea1c39

File tree

2 files changed

+91
-16
lines changed

2 files changed

+91
-16
lines changed

mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,10 @@ namespace {
120120
class MatchingSubsets {
121121
public:
122122
/// Insert a subset op.
123-
void insert(SubsetOpInterface op) {
123+
void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
124124
allSubsetOps.push_back(op);
125+
if (!collectHoistableOps)
126+
return;
125127
if (auto extractionOp =
126128
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
127129
insertExtractionOp(extractionOp);
@@ -148,6 +150,15 @@ class MatchingSubsets {
148150
});
149151
}
150152

153+
/// Populate subset ops starting from the given region iter_arg. Return
154+
/// "failure" if non-subset ops are found along the path to the loop yielding
155+
/// op or if there is no single path to the tied yielded operand. If
156+
/// `collectHoistableOps` is set to "false", subset ops are gathered
157+
/// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
158+
LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
159+
BlockArgument iterArg,
160+
bool collectHoistableOps = true);
161+
151162
private:
152163
/// Helper function for equivalence of tensor values. Since only insertion
153164
/// subset ops (that are also destination style ops) are followed when
@@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
225236
return nullptr;
226237
}
227238

228-
/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
229-
/// loop-like op and index into loop-invariant subset locations. Return the
230-
/// newly created loop op (that has extra iter_args) or the original loop op if
231-
/// nothing was hoisted.
232-
static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
233-
BlockArgument iterArg) {
234-
IRRewriter rewriter(loopLike.getContext());
239+
LogicalResult
240+
MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
241+
BlockArgument iterArg,
242+
bool collectHoistableOps) {
235243
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
236-
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
237-
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
238244
Value value = iterArg;
239-
MatchingSubsets subsets;
240245

241246
// Traverse use-def chain. Subset ops can be hoisted only if all ops along the
242247
// use-def chain starting from the region iter_arg are subset extraction or
@@ -249,36 +254,71 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
249254
Value nextValue = {};
250255

251256
for (OpOperand &use : value.getUses()) {
257+
if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
258+
// Subset ops in nested loops are collected to check if there are only
259+
// disjoint subset ops, but such subset ops are not subject to hoisting.
260+
// To hoist subset ops from nested loops, the hoisting transformation
261+
// should be run on the nested loop.
262+
auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
263+
if (!nestedIterArg)
264+
return failure();
265+
// Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
266+
// use-def chain starting at `nestedIterArg` and terminating in the
267+
// tied, yielding operand.
268+
if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
269+
/*collectHoistableOps=*/false)))
270+
return failure();
271+
nextValue = nestedLoop.getTiedLoopResult(&use);
272+
continue;
273+
}
274+
252275
auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
253276
if (!subsetOp)
254-
return loopLike;
255-
subsets.insert(subsetOp);
277+
return failure();
278+
insert(subsetOp);
256279

257280
if (auto insertionOp =
258281
dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
259282
// The value must be used as a destination. (In case of a source, the
260283
// entire tensor would be read, which would prevent any hoisting.)
261284
if (&use != &insertionOp.getDestinationOperand())
262-
return loopLike;
285+
return failure();
263286
// There must be a single use-def chain from the region iter_arg to the
264287
// terminator. I.e., only one insertion op. Branches are not supported.
265288
if (nextValue)
266-
return loopLike;
289+
return failure();
267290
nextValue = insertionOp.getUpdatedDestination();
268291
}
269292
}
270293

271294
// Nothing can be hoisted if the chain does not continue with loop yielding
272295
// op or a subset insertion op.
273296
if (!nextValue)
274-
return loopLike;
297+
return failure();
275298
value = nextValue;
276299
}
277300

278301
// Hoist only if the SSA use-def chain ends in the yielding terminator of the
279302
// loop and the yielded value is the `idx`-th operand. (I.e., there is no
280303
// swapping yield.)
281304
if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
305+
return failure();
306+
307+
return success();
308+
}
309+
310+
/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
311+
/// loop-like op and index into loop-invariant subset locations. Return the
312+
/// newly created loop op (that has extra iter_args) or the original loop op if
313+
/// nothing was hoisted.
314+
static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
315+
BlockArgument iterArg) {
316+
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
317+
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
318+
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
319+
IRRewriter rewriter(loopLike.getContext());
320+
MatchingSubsets subsets;
321+
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
282322
return loopLike;
283323

284324
// Hoist all matching extraction-insertion pairs one-by-one.

mlir/test/Transforms/loop-invariant-subset-hoisting.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,38 @@ func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {
235235

236236
return %0 : tensor<?xf32>
237237
}
238+
239+
// -----
240+
241+
// CHECK-LABEL: func @nested_hoisting(
242+
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
243+
func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
244+
%lb = "test.foo"() : () -> (index)
245+
%ub = "test.foo"() : () -> (index)
246+
%step = "test.foo"() : () -> (index)
247+
248+
// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
249+
// CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
250+
// CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
251+
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
252+
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
253+
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
254+
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
255+
%3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
256+
// CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
257+
%4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
258+
%5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
259+
// CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
260+
%6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
261+
%7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
262+
// CHECK: scf.yield %[[t2]], %[[foo2]]
263+
scf.yield %7 : tensor<?xf32>
264+
}
265+
// CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
266+
scf.yield %4 : tensor<?xf32>
267+
}
268+
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
269+
// CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
270+
// CHECK: return %[[insert2]]
271+
return %0 : tensor<?xf32>
272+
}

0 commit comments

Comments
 (0)