Skip to content

[mlir][Transforms] LISH: Improve bypass analysis for loop-like ops #70623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ namespace {
class MatchingSubsets {
public:
/// Insert a subset op.
void insert(SubsetOpInterface op) {
void insert(SubsetOpInterface op, bool collectHoistableOps = true) {
allSubsetOps.push_back(op);
if (!collectHoistableOps)
return;
if (auto extractionOp =
dyn_cast<SubsetExtractionOpInterface>(op.getOperation()))
insertExtractionOp(extractionOp);
Expand All @@ -148,6 +150,15 @@ class MatchingSubsets {
});
}

/// Populate subset ops starting from the given region iter_arg. Return
/// "failure" if non-subset ops are found along the path to the loop yielding
/// op or if there is no single path to the tied yielded operand. If
/// `collectHoistableOps` is set to "false", subset ops are gathered
/// throughout the traversal, but not enumerated by `getHoistableSubsetOps`.
LogicalResult populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
BlockArgument iterArg,
bool collectHoistableOps = true);

private:
/// Helper function for equivalence of tensor values. Since only insertion
/// subset ops (that are also destination style ops) are followed when
Expand Down Expand Up @@ -225,18 +236,12 @@ static OpOperand *getSingleTerminatorUse(Value value) {
return nullptr;
}

/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
/// loop-like op and index into loop-invariant subset locations. Return the
/// newly created loop op (that has extra iter_args) or the original loop op if
/// nothing was hoisted.
static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
IRRewriter rewriter(loopLike.getContext());
LogicalResult
MatchingSubsets::populateSubsetOpsAtIterArg(LoopLikeOpInterface loopLike,
BlockArgument iterArg,
bool collectHoistableOps) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
Value value = iterArg;
MatchingSubsets subsets;

// Traverse use-def chain. Subset ops can be hoisted only if all ops along the
// use-def chain starting from the region iter_arg are subset extraction or
Expand All @@ -249,36 +254,71 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
Value nextValue = {};

for (OpOperand &use : value.getUses()) {
if (auto nestedLoop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
// Subset ops in nested loops are collected to check if there are only
// disjoint subset ops, but such subset ops are not subject to hoisting.
// To hoist subset ops from nested loops, the hoisting transformation
// should be run on the nested loop.
auto nestedIterArg = nestedLoop.getTiedLoopRegionIterArg(&use);
if (!nestedIterArg)
return failure();
// Note: `populateSubsetOpsAtIterArg` fails if there is no single SSA
// use-def chain starting at `nestedIterArg` and terminating in the
// tied, yielding operand.
if (failed(populateSubsetOpsAtIterArg(nestedLoop, nestedIterArg,
/*collectHoistableOps=*/false)))
return failure();
nextValue = nestedLoop.getTiedLoopResult(&use);
continue;
}

auto subsetOp = dyn_cast<SubsetOpInterface>(use.getOwner());
if (!subsetOp)
return loopLike;
subsets.insert(subsetOp);
return failure();
insert(subsetOp);

if (auto insertionOp =
dyn_cast<SubsetInsertionOpInterface>(use.getOwner())) {
// The value must be used as a destination. (In case of a source, the
// entire tensor would be read, which would prevent any hoisting.)
if (&use != &insertionOp.getDestinationOperand())
return loopLike;
return failure();
// There must be a single use-def chain from the region iter_arg to the
// terminator. I.e., only one insertion op. Branches are not supported.
if (nextValue)
return loopLike;
return failure();
nextValue = insertionOp.getUpdatedDestination();
}
}

// Nothing can be hoisted if the chain does not continue with loop yielding
// op or a subset insertion op.
if (!nextValue)
return loopLike;
return failure();
value = nextValue;
}

// Hoist only if the SSA use-def chain ends in the yielding terminator of the
// loop and the yielded value is the `idx`-th operand. (I.e., there is no
// swapping yield.)
if (loopLike.getTiedLoopYieldedValue(iterArg) != yieldedOperand)
return failure();

return success();
}

/// Hoist all subset ops that operate on the idx-th region iter_arg of the given
/// loop-like op and index into loop-invariant subset locations. Return the
/// newly created loop op (that has extra iter_args) or the original loop op if
/// nothing was hoisted.
static LoopLikeOpInterface hoistSubsetAtIterArg(LoopLikeOpInterface loopLike,
BlockArgument iterArg) {
assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg");
auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg);
int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it);
IRRewriter rewriter(loopLike.getContext());
MatchingSubsets subsets;
if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg)))
return loopLike;

// Hoist all matching extraction-insertion pairs one-by-one.
Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Transforms/loop-invariant-subset-hoisting.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,38 @@ func.func @non_loop_invariant_subset_op(%arg: tensor<?xf32>) -> tensor<?xf32> {

return %0 : tensor<?xf32>
}

// -----

// CHECK-LABEL: func @nested_hoisting(
// CHECK-SAME: %[[arg:.*]]: tensor<?xf32>
func.func @nested_hoisting(%arg: tensor<?xf32>) -> tensor<?xf32> {
%lb = "test.foo"() : () -> (index)
%ub = "test.foo"() : () -> (index)
%step = "test.foo"() : () -> (index)

// CHECK: %[[extract:.*]] = tensor.extract_slice %[[arg]][0] [5] [1]
// CHECK: %[[extract2:.*]] = tensor.extract_slice %[[arg]][5] [5] [1]
// CHECK: %[[for:.*]]:3 = scf.for {{.*}} iter_args(%[[t:.*]] = %[[arg]], %[[hoisted:.*]] = %[[extract]], %[[hoisted2:.*]] = %[[extract2]])
%0 = scf.for %iv = %lb to %ub step %step iter_args(%t = %arg) -> (tensor<?xf32>) {
%1 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: %[[foo:.*]] = "test.foo"(%[[hoisted]])
%2 = "test.foo"(%1) : (tensor<5xf32>) -> (tensor<5xf32>)
%3 = tensor.insert_slice %2 into %t[0][5][1] : tensor<5xf32> into tensor<?xf32>
// CHECK: %[[for2:.*]]:2 = {{.*}} iter_args(%[[t2:.*]] = %[[t]], %[[hoisted2_nested:.*]] = %[[hoisted2]])
%4 = scf.for %iv2 = %lb to %ub step %step iter_args(%t2 = %3) -> (tensor<?xf32>) {
%5 = tensor.extract_slice %t2[5][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: %[[foo2:.*]] = "test.foo"(%[[hoisted2_nested]])
%6 = "test.foo"(%5) : (tensor<5xf32>) -> (tensor<5xf32>)
%7 = tensor.insert_slice %6 into %t2[5][5][1] : tensor<5xf32> into tensor<?xf32>
// CHECK: scf.yield %[[t2]], %[[foo2]]
scf.yield %7 : tensor<?xf32>
}
// CHECK: scf.yield %[[for2]]#0, %[[foo]], %[[for2]]#1
scf.yield %4 : tensor<?xf32>
}
// CHECK: %[[insert:.*]] = tensor.insert_slice %[[for]]#2 into %[[for]]#0[5] [5] [1]
// CHECK: %[[insert2:.*]] = tensor.insert_slice %[[for]]#1 into %[[insert]][0] [5] [1]
// CHECK: return %[[insert2]]
return %0 : tensor<?xf32>
}