Skip to content

Commit a461e0a

Browse files
[mlir][TilingInterface] NFC code changes separated out from the introduction of scf::tileUsingSCFForallOp.
This patch contains NFC changes that are a precursor to the introduction of the `scf::tileUsingSCFForallOp` method.
1 parent d13da15 commit a461e0a

File tree

8 files changed

+148
-127
lines changed

8 files changed

+148
-127
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct SCFTilingResult {
6060
/// of the last op.
6161
SmallVector<Operation *> tiledOps;
6262
/// The `scf.for` operations that iterate over the tiles.
63-
SmallVector<scf::ForOp> loops;
63+
SmallVector<Operation *> loops;
6464
/// Values to use as replacements for the untiled op. Is the same size as the
6565
/// number of results of the untiled op.
6666
SmallVector<Value> replacements;
@@ -160,7 +160,7 @@ struct SCFTileAndFuseResult {
160160
/// generated operation.
161161
llvm::SetVector<Operation *> tiledAndFusedOps;
162162
/// The `scf.for` operations that iterate over the tiles.
163-
SmallVector<scf::ForOp> loops;
163+
SmallVector<Operation *> loops;
164164
/// The replacement values to use for the tiled and fused operations.
165165
llvm::DenseMap<Value, Value> replacements;
166166
};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -434,16 +434,12 @@ static LogicalResult applyTilingToAll(
434434
SmallVector<Operation *> opsToReplace{target};
435435
llvm::append_range(opsToReplace, tiledResults->fusedProducers);
436436
for (Operation *toReplace : opsToReplace) {
437-
SmallVector<Value> replacements;
438-
replacements.reserve(toReplace->getNumResults());
439-
for (OpResult res : toReplace->getResults()) {
440-
auto it = tiledResults->replacements.find(res);
441-
if (it == tiledResults->replacements.end())
442-
replacements.push_back(res);
443-
else
444-
replacements.push_back(it->getSecond());
437+
for (OpResult res : toReplace->getResults())
438+
if (auto replacement = tiledResults->replacements.lookup(res))
439+
rewriter.replaceAllUsesWith(res, replacement);
440+
if (toReplace->use_empty()) {
441+
rewriter.eraseOp(toReplace);
445442
}
446-
rewriter.replaceOp(toReplace, replacements);
447443
}
448444

449445
// Report back the relevant handles to the transform op.

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 70 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,30 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
5555
return filledVector;
5656
}
5757

58+
/// Convert a list of ops of type `SrcOpTy` to a list of `Operation *`.
59+
template <typename SrcOpTy>
60+
static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
61+
return llvm::to_vector(
62+
llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
63+
}
64+
template <typename SrcOpTy>
65+
static SmallVector<Operation *>
66+
getAsOperations(const SmallVector<SrcOpTy> &ops) {
67+
return getAsOperations(ArrayRef<SrcOpTy>(ops));
68+
}
69+
70+
/// Convert a list of `Operation *` to a list of `DstOpTy`.
71+
template <typename DstOpTy>
72+
static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
73+
return llvm::to_vector(
74+
llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
75+
}
76+
template <typename DstOpTy>
77+
static SmallVector<DstOpTy>
78+
castToTypedOperations(const SmallVector<Operation *> &ops) {
79+
return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
80+
}
81+
5882
//===----------------------------------------------------------------------===//
5983
// tileUsingSCFForOp implementation.
6084
//===----------------------------------------------------------------------===//
@@ -77,10 +101,9 @@ static bool tileDividesIterationDomain(Range loopRange) {
77101
/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
78102
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
79103
Range loopRange, Value iv,
80-
Value tileSize) {
81-
std::optional<int64_t> ts = getConstantIntValue(tileSize);
82-
if (ts && ts.value() == 1)
83-
return getAsOpFoldResult(tileSize);
104+
OpFoldResult tileSize) {
105+
if (isConstantIntValue(tileSize, 1))
106+
return tileSize;
84107

85108
if (tileDividesIterationDomain(
86109
Range{loopRange.offset, loopRange.size, tileSize}))
@@ -295,8 +318,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
295318
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
296319
}
297320

298-
scf::SCFTilingResult tilingResult;
299321
SmallVector<OpFoldResult> offsets, sizes;
322+
SmallVector<scf::ForOp> forLoops;
300323
{
301324
// If there is an interchange specified, permute the iteration domain and
302325
// the tile sizes.
@@ -319,8 +342,8 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
319342
// 3. Materialize an empty loop nest that iterates over the tiles. These
320343
// loops for now do not return any values even if the original operation has
321344
// results.
322-
tilingResult.loops = generateTileLoopNest(
323-
rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes);
345+
forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
346+
tileSizeVector, offsets, sizes);
324347

325348
if (!interchangeVector.empty()) {
326349
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -330,30 +353,30 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
330353
}
331354

332355
LLVM_DEBUG({
333-
if (!tilingResult.loops.empty()) {
356+
if (!forLoops.empty()) {
334357
llvm::dbgs() << "LoopNest shell :\n";
335-
tilingResult.loops.front().dump();
358+
forLoops.front().dump();
336359
llvm::dbgs() << "\n";
337360
}
338361
});
339362

340363
// 4. Generate the tiled implementation within the inner most loop.
341-
if (!tilingResult.loops.empty())
342-
rewriter.setInsertionPoint(
343-
tilingResult.loops.back().getBody()->getTerminator());
364+
if (!forLoops.empty())
365+
rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
344366
FailureOr<TilingResult> tiledImplementation =
345367
op.getTiledImplementation(rewriter, offsets, sizes);
346-
tilingResult.tiledOps.append(tiledImplementation->tiledOps);
368+
347369
if (op->getNumResults() == 0) {
348-
// nothing more to do.
349-
return tilingResult;
370+
return scf::SCFTilingResult{
371+
tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
350372
}
351373

352374
// If loops are empty, the tiled op is used as the replacement for the untiled
353375
// op.
354-
if (tilingResult.loops.empty()) {
355-
tilingResult.replacements = tiledImplementation->tiledValues;
356-
return tilingResult;
376+
if (forLoops.empty()) {
377+
return scf::SCFTilingResult{tiledImplementation->tiledOps,
378+
getAsOperations(forLoops),
379+
tiledImplementation->tiledValues};
357380
}
358381

359382
// 5. Yield all the results of the tiled operation. The surrounding loop
@@ -377,18 +400,18 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
377400
destinationTensors)))
378401
return rewriter.notifyMatchFailure(op, "failed to get destinations");
379402

380-
tilingResult.replacements = yieldTiledValues(
403+
SmallVector<Value> replacements = yieldTiledValues(
381404
rewriter, destinationTensors, tiledImplementation.value(),
382-
resultOffsetsList, resultSizesList, tilingResult.loops);
383-
405+
resultOffsetsList, resultSizesList, forLoops);
384406
LLVM_DEBUG({
385-
if (!tilingResult.loops.empty()) {
407+
if (!forLoops.empty()) {
386408
llvm::dbgs() << "After tiled implementation :\n";
387-
tilingResult.loops.front().dump();
409+
forLoops.front().dump();
388410
llvm::dbgs() << "\n";
389411
}
390412
});
391-
return tilingResult;
413+
return scf::SCFTilingResult{tiledImplementation->tiledOps,
414+
getAsOperations(forLoops), replacements};
392415
}
393416

394417
FailureOr<scf::SCFReductionTilingResult>
@@ -466,6 +489,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
466489
results.mergeOp = mergeOp;
467490
return results;
468491
}
492+
469493
//===----------------------------------------------------------------------===//
470494
// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
471495
//===----------------------------------------------------------------------===//
@@ -636,28 +660,31 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
636660
}
637661

638662
// 1. First tile the consumer.
639-
scf::SCFTileAndFuseResult tileAndFuseResult;
663+
SmallVector<scf::ForOp> forLoops;
664+
SetVector<Operation *> fusedProducers, tiledAndFusedOps;
665+
DenseMap<Value, Value> replacements;
640666
llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber;
641667
{
642668
FailureOr<scf::SCFTilingResult> tilingResult =
643669
tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
644670
if (failed(tilingResult))
645671
return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
646672
for (auto *tiledOp : tilingResult->tiledOps)
647-
tileAndFuseResult.tiledAndFusedOps.insert(tiledOp);
648-
tileAndFuseResult.loops = std::move(tilingResult->loops);
649-
for (const auto &result : llvm::enumerate(
650-
llvm::zip(consumer->getResults(), tilingResult->replacements))) {
651-
tileAndFuseResult.replacements[std::get<0>(result.value())] =
652-
std::get<1>(result.value());
673+
tiledAndFusedOps.insert(tiledOp);
674+
forLoops = castToTypedOperations<scf::ForOp>(tilingResult->loops);
675+
for (auto [index, origValue, replacement] :
676+
llvm::enumerate(consumer->getResults(), tilingResult->replacements)) {
677+
replacements[origValue] = replacement;
653678
yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult(
654-
result.index())] = result.index();
679+
index)] = index;
655680
}
656681
}
657682

658683
// If there are no loops generated, fusion is immaterial.
659-
if (tileAndFuseResult.loops.empty())
660-
return tileAndFuseResult;
684+
if (forLoops.empty()) {
685+
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
686+
getAsOperations(forLoops), replacements};
687+
}
661688

662689
// 2. Typically, the operands of the tiled operation are slices of the
663690
// operands of the untiled operation. These are expressed in IR using
@@ -674,7 +701,7 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
674701
};
675702

676703
std::deque<tensor::ExtractSliceOp> candidates;
677-
addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates);
704+
addCandidateSlices(tiledAndFusedOps.back(), candidates);
678705
OpBuilder::InsertionGuard g(rewriter);
679706
while (!candidates.empty()) {
680707
// Traverse the slices in BFS fashion.
@@ -684,19 +711,20 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
684711
// The operands of the fused producer might themselves be slices of
685712
// values produced by operations that implement the `TilingInterface`.
686713
// Add these operations to the worklist.
687-
std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer =
688-
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp,
689-
tileAndFuseResult.loops);
690-
if (!fusedProducer)
714+
std::optional<scf::SCFFuseProducerOfSliceResult> fusedResult =
715+
tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, forLoops);
716+
if (!fusedResult)
691717
continue;
692718

693719
if (Operation *tiledAndFusedOp =
694-
fusedProducer->tiledAndFusedProducer.getDefiningOp()) {
695-
tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp);
720+
fusedResult->tiledAndFusedProducer.getDefiningOp()) {
721+
fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
722+
tiledAndFusedOps.insert(tiledAndFusedOp);
696723
addCandidateSlices(tiledAndFusedOp, candidates);
697724
}
698725
}
699-
return tileAndFuseResult;
726+
return scf::SCFTileAndFuseResult{fusedProducers, tiledAndFusedOps,
727+
getAsOperations(forLoops), replacements};
700728
}
701729

702730
//===----------------------------------------------------------------------===//

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
88
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
99
%init = tensor.empty(%d0, %d1) : tensor<?x?xf32>
1010
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
11-
%gemm = linalg.matmul {__internal_linalg_transform__ = "fusion"}
11+
%gemm = linalg.matmul {__internal_transform__ = "fusion"}
1212
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
1313
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
1414
return %gemm : tensor<?x?xf32>
@@ -47,7 +47,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
4747
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
4848
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
4949
%generic = linalg.generic {
50-
__internal_linalg_transform__ = "fusion",
50+
__internal_transform__ = "fusion",
5151
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>],
5252
iterator_types = ["parallel", "parallel"]}
5353
ins(%gemm, %arg2 : tensor<?x?xf32>, tensor<?xf32>) outs(%init : tensor<?x?xf32>) {
@@ -97,7 +97,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
9797
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
9898
%init1 = tensor.empty(%d0, %d2) : tensor<?x?xf32>
9999
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
100-
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"}
100+
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_fusion"}
101101
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
102102
return %gemm1 : tensor<?x?xf32>
103103
}
@@ -147,7 +147,7 @@ func.func @gemm_transpose_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32
147147
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
148148
%init1 = tensor.empty(%d1, %d0) : tensor<?x?xf32>
149149
%transpose = linalg.generic {
150-
__internal_linalg_transform__ = "fusion",
150+
__internal_transform__ = "fusion",
151151
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>],
152152
iterator_types = ["parallel", "parallel"]}
153153
ins(%gemm : tensor<?x?xf32>) outs(%init1 : tensor<?x?xf32>) {
@@ -198,7 +198,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
198198
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
199199
outs(%1 : tensor<?x?xf32>) -> tensor<?x?xf32>
200200
%3 = linalg.generic {
201-
__internal_linalg_transform__ = "gemm_interchange_fusion",
201+
__internal_transform__ = "gemm_interchange_fusion",
202202
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
203203
iterator_types = ["parallel", "parallel"]}
204204
ins(%2 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
@@ -249,7 +249,7 @@ func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
249249
affine_map<(d0, d1) -> (d0, d1)>,
250250
affine_map<(d0, d1) -> (d0, d1)>],
251251
iterator_types = ["parallel", "parallel"],
252-
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
252+
__internal_transform__ = "gemm_plus_gemm_fusion"}
253253
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
254254
outs(%5 : tensor<?x?xf32>) {
255255
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -302,7 +302,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x
302302
affine_map<(d0, d1) -> (d1, d0)>,
303303
affine_map<(d0, d1) -> (d0, d1)>],
304304
iterator_types = ["parallel", "parallel"],
305-
__internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
305+
__internal_transform__ = "gemm_plus_gemm_fusion"}
306306
ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
307307
outs(%5 : tensor<?x?xf32>) {
308308
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
@@ -352,7 +352,7 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
352352
%1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
353353
outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
354354
%2 = linalg.matmul
355-
{__internal_linalg_transform__ = "gemm_sequence_fusion"}
355+
{__internal_transform__ = "gemm_sequence_fusion"}
356356
ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
357357
outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
358358
return %2 : tensor<?x?xf32>
@@ -425,7 +425,7 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> {
425425
linalg.yield %10, %9 : f32, f32
426426
} -> (tensor<30xf32>, tensor<30x3xf32>)
427427
%6 = linalg.generic {
428-
__internal_linalg_transform__ = "reduction_sequence_fusion",
428+
__internal_transform__ = "reduction_sequence_fusion",
429429
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>,
430430
affine_map<(d0, d1) -> (d0, d1)>],
431431
iterator_types = ["parallel", "parallel"]}

mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?
1313
ins(%lhs0, %rhs0 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill0 : tensor<?x?xf32>) -> tensor<?x?xf32>
1414
%d2 = tensor.dim %rhs1, %c1 : tensor<?x?xf32>
1515
%fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<?x?xf32>) -> tensor<?x?xf32>
16-
%gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_sequence_fusion_and_yield"}
16+
%gemm1 = linalg.matmul {__internal_transform__ = "gemm_sequence_fusion_and_yield"}
1717
ins(%gemm0, %rhs1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%fill1 : tensor<?x?xf32>) -> tensor<?x?xf32>
1818
return %gemm0, %gemm1 : tensor<?x?xf32>, tensor<?x?xf32>
1919
}

0 commit comments

Comments
 (0)