Skip to content

[mlir] Add option for a cleanup pattern set to SCF tiling helper #109554

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
let description = [{
Tiles the operations pointed to by the target handle and fuses their
producers greedily using the options provided as attributes.

If `apply_cleanup` is true then slice canonicalization is applied between
fusion steps.
}];

let arguments =
(ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
let results = (outs TransformHandleTypeInterface:$transformed,
Variadic<TransformHandleTypeInterface>:$loops);

let assemblyFormat = [{
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
attr-dict `:` functional-type(operands, results)
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
`:` functional-type(operands, results)
}];
let hasVerifier = 1;
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"

#include <deque>

Expand Down Expand Up @@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions {
fusionControlFn = controlFn;
return *this;
}

/// An optional set of rewrite patterns to apply to the results of tiling
/// before fusion. This will track deleted and newly inserted
/// `tensor.extract_slice` ops and update the worklist.
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
};

/// Fuse the producer of the source of `candidateSliceOp` by computing the
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
scf::SCFTileAndFuseOptions tileAndFuseOptions;
tileAndFuseOptions.tilingOptions = tilingOptions;

if (getApplyCleanup()) {
MLIRContext *context = rewriter.getContext();
RewritePatternSet patterns(context);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}

LogicalResult result = applyTilingToAll(
rewriter, getOperation(), state.getPayloadOps(getTarget()),
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
Expand Down
156 changes: 130 additions & 26 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
Expand Down Expand Up @@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
return generatedSlices;
}

namespace {

//===----------------------------------------------------------------------===//
// SliceTrackingListener
//===----------------------------------------------------------------------===//

/// This class is a listener for tracking the insertion and removal of
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
/// fusion algorithm to apply cleanup patterns in between fusion steps.
class SliceTrackingListener : public RewriterBase::Listener {
public:
explicit SliceTrackingListener(
std::optional<FrozenRewritePatternSet> patterns);
SliceTrackingListener() = default;

/// Adds the given list of operations to the worklist, and if present, applies
/// the list of `patterns` to the newly added operations. This only processes
/// the given operations and any newly inserted ones by the pattern set.
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);

/// Add to the new operation worklist if it is an extract_slice.
void notifyOperationInserted(Operation *op,
OpBuilder::InsertPoint previous) override;

/// Shared helper for operation removal from the worklist.
void removeOp(Operation *op);

/// Remove the operation from the worklist.
void notifyOperationErased(Operation *op) override;

/// Remove the operation from the worklist.
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;

/// The worklist for this transformation keeps track of the slices to visit
/// next for fusion.
std::deque<tensor::ExtractSliceOp> worklist;

private:
/// Optional pattern set to apply when adding new operations to the worklist.
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
};

/// Takes ownership of the optional cleanup pattern set. The member is
/// initialized directly from the argument rather than being default-constructed
/// and then assigned in the constructor body.
SliceTrackingListener::SliceTrackingListener(
    std::optional<FrozenRewritePatternSet> p)
    : patterns(std::move(p)) {}

LogicalResult
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
  // Queue every incoming `tensor.extract_slice`; other ops are not tracked.
  for (Operation *candidate : ops)
    if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate))
      worklist.push_back(sliceOp);

  // Without a cleanup pattern set there is nothing further to do.
  if (!patterns.has_value())
    return success();

  // Run the patterns over `ops` with this object installed as the driver's
  // listener, so its notify* hooks keep `worklist` in sync with the IR.
  // NOTE(review): a reviewer remarked that `ExistingAndNewOps` looks like it
  // should be effectively the same as `AnyOp` here -- unconfirmed.
  GreedyRewriteConfig driverConfig;
  driverConfig.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
  driverConfig.listener = this;
  return applyOpPatternsAndFold(ops, *patterns, driverConfig);
}

// Insertion hook: a newly inserted `tensor.extract_slice` is appended to the
// worklist; any other kind of op is ignored. `previous` is not used.
void SliceTrackingListener::notifyOperationInserted(
    Operation *op, OpBuilder::InsertPoint previous) {
  if (auto insertedSlice = dyn_cast<tensor::ExtractSliceOp>(op))
    worklist.push_back(insertedSlice);
}

// Drops `op` from the worklist if it is a `tensor.extract_slice` that is
// currently queued; only the first matching entry is erased. A linear scan is
// used on the expectation that the worklist stays small and that removals are
// relatively rare.
void SliceTrackingListener::removeOp(Operation *op) {
  if (!isa<tensor::ExtractSliceOp>(op))
    return;

  for (auto it = worklist.begin(), end = worklist.end(); it != end; ++it) {
    if (*it == op) {
      // Erasing invalidates the iterators, so stop immediately afterwards.
      worklist.erase(it);
      return;
    }
  }
}

// Erase hook: forwards to `removeOp` so that `op` is dropped from the worklist
// if it is a queued `tensor.extract_slice`.
void SliceTrackingListener::notifyOperationErased(Operation *op) {
removeOp(op);
}

// Replace hook: forwards to `removeOp` so that `op` is dropped from the
// worklist if it is a queued `tensor.extract_slice`. The `replacement` values
// are not inspected.
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
ValueRange replacement) {
removeOp(op);
}
} // namespace

/// Implementation of tile consumer and fuse producer greedily.
FailureOr<scf::SCFTileAndFuseResult>
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
Expand Down Expand Up @@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
tensor::ExtractSliceOp candidateSlice;
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
};
std::deque<WorklistItem> worklist;
auto addCandidateSlices = [&worklist, &options,
&loops](ArrayRef<Operation *> candidates) {
for (auto candidate : candidates) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
if (!sliceOp || sliceOp.use_empty())
continue;

auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
if (!fusableProducer)
continue;
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
options.fusionControlFn(sliceOp, fusableProducer,
destinationInitArg.has_value());
if (!controlFnResult)
continue;
worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
}
};
SliceTrackingListener sliceTracker =
SliceTrackingListener(options.cleanupPatterns);

addCandidateSlices(tilingResult->generatedSlices);
if (failed(
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
}
OpBuilder::InsertionGuard g(rewriter);
while (!worklist.empty()) {
// Traverse the slices in BFS fashion.
WorklistItem worklistItem = worklist.front();
worklist.pop_front();
while (!sliceTracker.worklist.empty()) {
auto candidateSlice = sliceTracker.worklist.front();
sliceTracker.worklist.pop_front();

auto [fusableProducer, destinationInitArg] =
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
loops);
if (!fusableProducer)
continue;

std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
options.fusionControlFn(candidateSlice, fusableProducer,
destinationInitArg.has_value());
if (!controlFnResult)
continue;

WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};

// The operands of the fused producer might themselves be slices of
// values produced by operations that implement the `TilingInterface`.
Expand All @@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
if (!fusedResult)
continue;

SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;

if (worklistItem.controlFnResult.yieldProducerReplacement) {
// Reconstruct and yield all opResult of fusableProducerOp by default. The
// caller can specify which one to yield by designating optional argument
Expand All @@ -1421,20 +1522,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
fusableProducerOp, "failed to replacement value for this "
"operation from within the tiled loop");
}
addCandidateSlices(newSlices.value());
worklistCandidates.append(newSlices.value());
for (auto [index, result] :
llvm::enumerate(fusableProducerOp->getResults())) {
origValToResultNumber[result] = loops.front()->getNumResults() -
fusableProducerOp->getNumResults() +
index;
}
}
addCandidateSlices(fusedResult->generatedSlices);
if (Operation *tiledAndFusedOp =
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
tiledAndFusedOps.insert(tiledAndFusedOp);
}

if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
}
}

DenseMap<Value, Value> replacements;
Expand Down
100 changes: 100 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-fuse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,103 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @fuse_through_slice
// Test input: the consumer (elemwise_binary) reads the producer
// (elemwise_unary) only through a tensor.extract_slice. The expectations in
// the body require both linalg ops to be printed after two scf.for headers and
// the function to return the result of the first matched scf.for.
func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.elemwise_unary
// CHECK: linalg.elemwise_binary
// CHECK: return %[[RES]]
// Producer that is expected to be fused.
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
// Slice of the producer result, sized by the dimensions of %arg1, that sits
// between the producer and the consumer.
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// Consumer; this is the op matched and tiled by the transform script.
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %2 : tensor<?x?xf32>
}

// Transform script: matches the linalg.elemwise_binary ops in the payload and
// applies transform.structured.fuse to them with tile sizes [32, 32],
// interchange [0, 1] and apply_cleanup enabled.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain
// Test input: the consumer (elemwise_binary) reaches the producer
// (elemwise_unary) through a chain of tensor.cast -> tensor.extract_slice ->
// tensor.cast -> tensor.extract_slice. The expectations in the body require
// both linalg ops to be printed after two scf.for headers and the function to
// return the result of the first matched scf.for.
func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {

// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.elemwise_unary
// CHECK: linalg.elemwise_binary
// CHECK: return %[[RES]]
// Producer that is expected to be fused.
%0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
// Cast / static slice / cast chain applied to the producer result.
%1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
%2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
%3 = tensor.cast %2 : tensor<98x98xf32> to tensor<?x?xf32>
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
// Final dynamic slice, sized by the dimensions of %arg1, feeding the consumer.
%4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// Consumer; this is the op matched and tiled by the transform script.
%5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}

// Transform script: matches the linalg.elemwise_binary ops in the payload and
// applies transform.structured.fuse to them with tile sizes [32, 32],
// interchange [0, 1] and apply_cleanup enabled.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
}

// -----

// CHECK-LABEL: func.func @fuse_unrelated_slices
// Test input: besides the producer/slice/consumer chain that is fused, the
// function contains two chained slices of %arg0 (%slice1, %slice2) that do not
// feed the consumer. The expectations in the body require two
// tensor.extract_slice ops, the second reading the first, to be printed ahead
// of the loop nest, and the second one to be returned next to the loop result.
func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {

// CHECK: %[[SLICE1:.+]] = tensor.extract_slice
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
// CHECK: %[[RES:.*]] = scf.for
// CHECK: scf.for
// CHECK: linalg.elemwise_unary
// CHECK: linalg.elemwise_binary
// CHECK: return %[[RES]], %[[SLICE2]]
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
// Slices that are not consumed by the linalg ops below.
%slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
%slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
// Producer that is expected to be fused.
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
// Slice of the producer result feeding the consumer.
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
// Consumer; this is the op matched and tiled by the transform script.
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
}

// Transform script: matches the linalg.elemwise_binary ops in the payload and
// applies transform.structured.fuse to them with tile sizes [32, 32],
// interchange [0, 1] and apply_cleanup enabled.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
transform.yield
}
}
Loading