Skip to content

Commit d93c9f9

Browse files
qedawkinsagozillon
authored andcommitted
[𝘀𝗽𝗿] changes to main this commit is based on
Created using spr 1.3.4 [skip ci]
1 parent f74879c commit d93c9f9

File tree

5 files changed

+252
-28
lines changed

5 files changed

+252
-28
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,18 +295,23 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
295295
let description = [{
296296
Tiles the operations pointed to by the target handle and fuses their
297297
producers greedily using the options provided as attributes.
298+
299+
If `apply_cleanup` is true then slice canonicalization is applied between
300+
fusion steps.
298301
}];
299302

300303
let arguments =
301304
(ins TransformHandleTypeInterface:$target,
302305
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
303-
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
306+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
307+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
304308
let results = (outs TransformHandleTypeInterface:$transformed,
305309
Variadic<TransformHandleTypeInterface>:$loops);
306310

307311
let assemblyFormat = [{
308312
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
309-
attr-dict `:` functional-type(operands, results)
313+
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
314+
`:` functional-type(operands, results)
310315
}];
311316
let hasVerifier = 1;
312317
}

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Interfaces/LoopLikeInterface.h"
1616
#include "mlir/Interfaces/TilingInterface.h"
1717
#include "mlir/Interfaces/ViewLikeInterface.h"
18+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
1819

1920
#include <deque>
2021

@@ -153,6 +154,11 @@ struct SCFTileAndFuseOptions {
153154
fusionControlFn = controlFn;
154155
return *this;
155156
}
157+
158+
/// An optional set of rewrite patterns to apply to the results of tiling
159+
/// before fusion. This will track deleted and newly inserted
160+
/// `tensor.extract_slice` ops and update the worklist.
161+
std::optional<FrozenRewritePatternSet> cleanupPatterns = std::nullopt;
156162
};
157163

158164
/// Fuse the producer of the source of `candidateSliceOp` by computing the

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,15 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
562562
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
563563
scf::SCFTileAndFuseOptions tileAndFuseOptions;
564564
tileAndFuseOptions.tilingOptions = tilingOptions;
565+
566+
if (getApplyCleanup()) {
567+
MLIRContext *context = rewriter.getContext();
568+
RewritePatternSet patterns(context);
569+
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
570+
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
571+
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
572+
}
573+
565574
LogicalResult result = applyTilingToAll(
566575
rewriter, getOperation(), state.getPayloadOps(getTarget()),
567576
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 130 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "mlir/IR/PatternMatch.h"
2525
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2626
#include "mlir/Interfaces/TilingInterface.h"
27+
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2729
#include "llvm/ADT/TypeSwitch.h"
2830
#include "llvm/Support/Debug.h"
2931
#include <optional>
@@ -1315,6 +1317,104 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
13151317
return generatedSlices;
13161318
}
13171319

1320+
namespace {
1321+
1322+
//===----------------------------------------------------------------------===//
1323+
// SliceTrackingListener
1324+
//===----------------------------------------------------------------------===//
1325+
1326+
/// This class is a listener for tracking the insertion and removal of
1327+
/// `tensor.extract_slice` ops in a worklist. This can be used in a greedy
1328+
/// fusion algorithm to apply cleanup patterns in between fusion steps.
1329+
class SliceTrackingListener : public RewriterBase::Listener {
1330+
public:
1331+
explicit SliceTrackingListener(
1332+
std::optional<FrozenRewritePatternSet> patterns);
1333+
SliceTrackingListener() = default;
1334+
1335+
/// Adds the given list of operations to the worklist, and if present, applies
1336+
/// the list of `patterns` to the newly added operations. This only processes
1337+
/// the given operations and any newly inserted ones by the pattern set.
1338+
LogicalResult insertAndApplyPatterns(ArrayRef<Operation *> newOps);
1339+
1340+
/// Add to the new operation worklist if it is an extract_slice.
1341+
void notifyOperationInserted(Operation *op,
1342+
OpBuilder::InsertPoint previous) override;
1343+
1344+
/// Shared helper for operation removal from the worklist.
1345+
void removeOp(Operation *op);
1346+
1347+
/// Remove the operation from the worklist.
1348+
void notifyOperationErased(Operation *op) override;
1349+
1350+
/// Remove the operation from the worklist.
1351+
void notifyOperationReplaced(Operation *op, ValueRange replacement) override;
1352+
1353+
/// The worklist for this transformation keeps track of the slices to visit
1354+
/// next for fusion.
1355+
std::deque<tensor::ExtractSliceOp> worklist;
1356+
1357+
private:
1358+
/// Optional pattern set to apply when adding new operations to the worklist.
1359+
std::optional<FrozenRewritePatternSet> patterns = std::nullopt;
1360+
};
1361+
1362+
SliceTrackingListener::SliceTrackingListener(
1363+
std::optional<FrozenRewritePatternSet> p) {
1364+
patterns = std::move(p);
1365+
}
1366+
1367+
LogicalResult
1368+
SliceTrackingListener::insertAndApplyPatterns(ArrayRef<Operation *> ops) {
1369+
for (Operation *op : ops) {
1370+
if (auto slice = dyn_cast<tensor::ExtractSliceOp>(op))
1371+
worklist.push_back(slice);
1372+
}
1373+
1374+
if (!patterns)
1375+
return success();
1376+
1377+
GreedyRewriteConfig config;
1378+
config.listener = this;
1379+
config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps;
1380+
return applyOpPatternsAndFold(ops, patterns.value(), config);
1381+
}
1382+
1383+
void SliceTrackingListener::notifyOperationInserted(
1384+
Operation *op, OpBuilder::InsertPoint previous) {
1385+
auto slice = dyn_cast<tensor::ExtractSliceOp>(op);
1386+
if (!slice)
1387+
return;
1388+
worklist.push_back(slice);
1389+
}
1390+
1391+
// Scan the worklist for the given op and remove it if present. The expectation
1392+
// is for the worklist to be small and for removal to be relatively rare.
1393+
void SliceTrackingListener::removeOp(Operation *op) {
1394+
if (!isa<tensor::ExtractSliceOp>(op))
1395+
return;
1396+
auto iter = worklist.begin();
1397+
while (iter != worklist.end()) {
1398+
if (*iter == op)
1399+
break;
1400+
iter++;
1401+
}
1402+
if (iter == worklist.end())
1403+
return;
1404+
1405+
worklist.erase(iter);
1406+
}
1407+
1408+
void SliceTrackingListener::notifyOperationErased(Operation *op) {
1409+
removeOp(op);
1410+
}
1411+
1412+
void SliceTrackingListener::notifyOperationReplaced(Operation *op,
1413+
ValueRange replacement) {
1414+
removeOp(op);
1415+
}
1416+
} // namespace
1417+
13181418
/// Implementation of tile consumer and fuse producer greedily.
13191419
FailureOr<scf::SCFTileAndFuseResult>
13201420
mlir::scf::tileConsumerAndFuseProducersUsingSCF(
@@ -1370,33 +1470,32 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
13701470
tensor::ExtractSliceOp candidateSlice;
13711471
SCFTileAndFuseOptions::ControlFnResult controlFnResult;
13721472
};
1373-
std::deque<WorklistItem> worklist;
1374-
auto addCandidateSlices = [&worklist, &options,
1375-
&loops](ArrayRef<Operation *> candidates) {
1376-
for (auto candidate : candidates) {
1377-
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(candidate);
1378-
if (!sliceOp || sliceOp.use_empty())
1379-
continue;
13801473

1381-
auto [fusableProducer, destinationInitArg] =
1382-
getUntiledProducerFromSliceSource(&sliceOp.getSourceMutable(), loops);
1383-
if (!fusableProducer)
1384-
continue;
1385-
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1386-
options.fusionControlFn(sliceOp, fusableProducer,
1387-
destinationInitArg.has_value());
1388-
if (!controlFnResult)
1389-
continue;
1390-
worklist.emplace_back(WorklistItem{sliceOp, controlFnResult.value()});
1391-
}
1392-
};
1474+
SliceTrackingListener sliceTracker =
1475+
SliceTrackingListener(options.cleanupPatterns);
13931476

1394-
addCandidateSlices(tilingResult->generatedSlices);
1477+
if (failed(
1478+
sliceTracker.insertAndApplyPatterns(tilingResult->generatedSlices))) {
1479+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1480+
}
13951481
OpBuilder::InsertionGuard g(rewriter);
1396-
while (!worklist.empty()) {
1397-
// Traverse the slices in BFS fashion.
1398-
WorklistItem worklistItem = worklist.front();
1399-
worklist.pop_front();
1482+
while (!sliceTracker.worklist.empty()) {
1483+
auto candidateSlice = sliceTracker.worklist.front();
1484+
sliceTracker.worklist.pop_front();
1485+
1486+
auto [fusableProducer, destinationInitArg] =
1487+
getUntiledProducerFromSliceSource(&candidateSlice.getSourceMutable(),
1488+
loops);
1489+
if (!fusableProducer)
1490+
continue;
1491+
1492+
std::optional<SCFTileAndFuseOptions::ControlFnResult> controlFnResult =
1493+
options.fusionControlFn(candidateSlice, fusableProducer,
1494+
destinationInitArg.has_value());
1495+
if (!controlFnResult)
1496+
continue;
1497+
1498+
WorklistItem worklistItem = {candidateSlice, controlFnResult.value()};
14001499

14011500
// The operands of the fused producer might themselved be slices of
14021501
// values produced by operations that implement the `TilingInterface`.
@@ -1407,6 +1506,8 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14071506
if (!fusedResult)
14081507
continue;
14091508

1509+
SmallVector<Operation *> worklistCandidates = fusedResult->generatedSlices;
1510+
14101511
if (worklistItem.controlFnResult.yieldProducerReplacement) {
14111512
// Reconstruct and yield all opResult of fusableProducerOp by default. The
14121513
// caller can specific which one to yield by designating optional argument
@@ -1421,20 +1522,23 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
14211522
fusableProducerOp, "failed to replacement value for this "
14221523
"operation from within the tiled loop");
14231524
}
1424-
addCandidateSlices(newSlices.value());
1525+
worklistCandidates.append(newSlices.value());
14251526
for (auto [index, result] :
14261527
llvm::enumerate(fusableProducerOp->getResults())) {
14271528
origValToResultNumber[result] = loops.front()->getNumResults() -
14281529
fusableProducerOp->getNumResults() +
14291530
index;
14301531
}
14311532
}
1432-
addCandidateSlices(fusedResult->generatedSlices);
14331533
if (Operation *tiledAndFusedOp =
14341534
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
14351535
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
14361536
tiledAndFusedOps.insert(tiledAndFusedOp);
14371537
}
1538+
1539+
if (failed(sliceTracker.insertAndApplyPatterns(worklistCandidates))) {
1540+
return rewriter.notifyMatchFailure(consumer, "cleanup patterns failed");
1541+
}
14381542
}
14391543

14401544
DenseMap<Value, Value> replacements;

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,103 @@ module attributes {transform.with_named_sequence} {
178178
transform.yield
179179
}
180180
}
181+
182+
// -----
183+
184+
// CHECK-LABEL: func.func @fuse_through_slice
185+
func.func @fuse_through_slice(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
186+
187+
// CHECK: %[[RES:.*]] = scf.for
188+
// CHECK: scf.for
189+
// CHECK: linalg.elemwise_unary
190+
// CHECK: linalg.elemwise_binary
191+
// CHECK: return %[[RES]]
192+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
193+
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
194+
%c0 = arith.constant 0 : index
195+
%c1 = arith.constant 1 : index
196+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
197+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
198+
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
199+
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
200+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
201+
return %2 : tensor<?x?xf32>
202+
}
203+
204+
module attributes {transform.with_named_sequence} {
205+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
206+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
207+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
208+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
209+
transform.yield
210+
}
211+
}
212+
213+
// -----
214+
215+
// CHECK-LABEL: func.func @fuse_through_slice_and_cast_chain
216+
func.func @fuse_through_slice_and_cast_chain(%arg0: tensor<100x100xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
217+
218+
// CHECK: %[[RES:.*]] = scf.for
219+
// CHECK: scf.for
220+
// CHECK: linalg.elemwise_unary
221+
// CHECK: linalg.elemwise_binary
222+
// CHECK: return %[[RES]]
223+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<100x100xf32>)
224+
outs(%arg0: tensor<100x100xf32>) -> tensor<100x100xf32>
225+
%1 = tensor.cast %0 : tensor<100x100xf32> to tensor<100x?xf32>
226+
%2 = tensor.extract_slice %1 [1, 1] [98, 98] [1, 1] : tensor<100x?xf32> to tensor<98x98xf32>
227+
%3 = tensor.cast %2 : tensor<98x98xf32> to tensor<?x?xf32>
228+
%c0 = arith.constant 0 : index
229+
%c1 = arith.constant 1 : index
230+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
231+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
232+
%4 = tensor.extract_slice %3 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
233+
%5 = linalg.elemwise_binary ins(%4, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
234+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
235+
return %5 : tensor<?x?xf32>
236+
}
237+
238+
module attributes {transform.with_named_sequence} {
239+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
240+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
241+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
242+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
243+
transform.yield
244+
}
245+
}
246+
247+
// -----
248+
249+
// CHECK-LABEL: func.func @fuse_unrelated_slice
250+
func.func @fuse_unrelated_slices(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<10x10xf32>) {
251+
252+
// CHECK: %[[SLICE1:.+]] = tensor.extract_slice
253+
// CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[SLICE1]]
254+
// CHECK: %[[RES:.*]] = scf.for
255+
// CHECK: scf.for
256+
// CHECK: linalg.elemwise_unary
257+
// CHECK: linalg.elemwise_binary
258+
// CHECK: return %[[RES]], %[[SLICE2]]
259+
%c0 = arith.constant 0 : index
260+
%c1 = arith.constant 1 : index
261+
%dim0 = tensor.dim %arg1, %c0 : tensor<?x?xf32>
262+
%dim1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
263+
%slice1 = tensor.extract_slice %arg0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
264+
%slice2 = tensor.extract_slice %slice1 [1, 1] [10, 10] [1, 1] : tensor<?x?xf32> to tensor<10x10xf32>
265+
%0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
266+
outs(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32>
267+
%1 = tensor.extract_slice %0 [1, 1] [%dim0, %dim1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
268+
%2 = linalg.elemwise_binary ins(%1, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
269+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
270+
return %2, %slice2 : tensor<?x?xf32>, tensor<10x10xf32>
271+
}
272+
273+
module attributes {transform.with_named_sequence} {
274+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
275+
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1 : (!transform.any_op) -> !transform.any_op
276+
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1], apply_cleanup = true}
277+
: (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
278+
transform.yield
279+
}
280+
}

0 commit comments

Comments
 (0)