Skip to content

Commit f2b2872

Browse files
[mlir][SCF] Allow tiling by specifying maximum number of tiles.
1 parent 3cc288a commit f2b2872

File tree

8 files changed

+272
-303
lines changed

8 files changed

+272
-303
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class GenericOp;
3030
class LinalgOp;
3131
} // namespace linalg
3232

33+
namespace scf {
34+
struct SCFTilingResult;
35+
} // namespace scf
36+
3337
namespace tensor {
3438
class InsertSliceOp;
3539
class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
6064
ArrayRef<OpFoldResult> mixedNumThreads,
6165
ArrayRef<OpFoldResult> mixedTileSizes,
6266
std::optional<ArrayAttr> mapping,
63-
linalg::ForallTilingResult &tilingResult);
67+
scf::SCFTilingResult &tilingResult);
6468

6569
} // namespace transform
6670
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -866,31 +866,8 @@ FailureOr<ContinuousTileSizeSpecification>
866866
computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
867867
unsigned dimension, OpFoldResult targetSize,
868868
bool emitAssertions);
869-
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
870-
/// tiling by `numThreads`.
871-
/// If non-empty, the `mapping` is added as an attribute to the
872-
/// resulting `scf.forall`.
873-
/// Zero tile sizes indicate that the dimension is not tiled, and can be
874-
/// thought of as tiling by the full size of data. It is the user's
875-
/// responsibility to ensure that `numThreads` is a valid tiling specification
876-
/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
877-
struct ForallTilingResult {
878-
Operation *tileOp;
879-
Operation *tiledOp;
880-
};
881-
FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
882-
TilingInterface op,
883-
ArrayRef<OpFoldResult> numThreads,
884-
std::optional<ArrayAttr> mapping);
885-
886-
/// Same as `tileToForallOp`, but calculate the number of threads
887-
/// required using the given tileSizes.
888-
FailureOr<ForallTilingResult>
889-
tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
890-
ArrayRef<OpFoldResult> tileSizes,
891-
std::optional<ArrayAttr> mapping);
892-
893-
/// Transformation information returned after reduction tiling.
869+
870+
/// Transformation information returned after reduction tiling.
894871
struct ForallReductionTilingResult {
895872
/// The partial reduction tiled op generated.
896873
SmallVector<Operation *> parallelTiledOps;

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ using SCFTileSizeComputationFunction =
3232

3333
/// Options to use to control tiling.
3434
struct SCFTilingOptions {
35-
/// Computation function that returns the tile sizes for each operation.
36-
/// Delayed construction of constant tile sizes should occur to interoperate
37-
/// with folding.
35+
/// Computation function that returns the tile sizes to use for each loop.
36+
/// Returning a tile size of zero implies no tiling for that loop. If the
37+
/// size of the returned vector is smaller than the number of loops, the inner
38+
/// loops are not tiled. If the size of the returned vector is larger, then
39+
/// the vector is truncated to number of loops. Only one of
40+
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
41+
/// be used.
3842
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
3943

4044
SCFTilingOptions &
@@ -45,7 +49,25 @@ struct SCFTilingOptions {
4549
/// Convenience function to set the `tileSizeComputationFunction` to a
4650
/// function that computes tile sizes at the point they are needed. Allows
4751
/// proper interaction with folding.
48-
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
52+
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
53+
54+
/// Computation function that returns the maximum number of tile to use for
55+
/// each loop. Returning a tile size of zero implies no tiling for that loop.
56+
/// If the size of the returned vector is smaller than the number of loops,
57+
/// the inner loops are not tiled. If the size of the returned vector is
58+
/// larger, then the vector is truncated to number of loops. Only one of
59+
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
60+
/// be used.
61+
SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
62+
63+
SCFTilingOptions &
64+
setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
65+
maxNumTilesComputationFunction = std::move(fun);
66+
return *this;
67+
}
68+
/// Convenience function to set the `tileSizeComputationFunction` to a
69+
/// function that computes tile sizes at the point they are needed.
70+
SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
4971

5072
/// The interchange vector to reorder the tiled loops.
5173
SmallVector<int64_t> interchangeVector = {};
@@ -67,9 +89,8 @@ struct SCFTilingOptions {
6789
/// when using loop constructs that dont support such a mapping (like
6890
/// `scf.for`)
6991
SmallVector<Attribute> mappingVector = {};
70-
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
71-
mappingVector = llvm::map_to_vector(
72-
mapping, [](auto attr) -> Attribute { return attr; });
92+
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
93+
mappingVector = llvm::to_vector(mapping);
7394
return *this;
7495
}
7596
};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3156,7 +3156,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
31563156
TransformOpInterface transformOp, Operation *target,
31573157
ArrayRef<OpFoldResult> mixedNumThreads,
31583158
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
3159-
linalg::ForallTilingResult &tilingResult) {
3159+
scf::SCFTilingResult &tilingResult) {
31603160
// Transform all targets one by one.
31613161
auto tileableOp = dyn_cast<TilingInterface>(target);
31623162
if (!tileableOp) {
@@ -3167,18 +3167,38 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
31673167
return diag;
31683168
}
31693169
rewriter.setInsertionPoint(tileableOp);
3170-
FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
3170+
scf::SCFTilingOptions options;
3171+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
31713172
if (!mixedNumThreads.empty()) {
3172-
maybeTilingResult =
3173-
linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
3173+
options.setMaxNumTiles(mixedNumThreads);
31743174
} else {
3175-
maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
3176-
rewriter, tileableOp, mixedTileSizes, mapping);
3175+
SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
3176+
unsigned nLoops = loopRanges.size();
3177+
SmallVector<OpFoldResult> numThreads;
3178+
numThreads.reserve(nLoops);
3179+
AffineExpr s0, s1;
3180+
bindSymbols(rewriter.getContext(), s0, s1);
3181+
AffineExpr divExpr = s0.ceilDiv(s1);
3182+
for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
3183+
i < e; ++i) {
3184+
OpFoldResult numTiles = mixedTileSizes[i];
3185+
if (!isConstantIntValue(numTiles, 0))
3186+
numTiles = affine::makeComposedFoldedAffineApply(
3187+
rewriter, tileableOp.getLoc(), divExpr,
3188+
{loopRanges[i].size, numTiles});
3189+
numThreads.push_back(numTiles);
3190+
}
3191+
options.setMaxNumTiles(numThreads);
3192+
}
3193+
if (mapping) {
3194+
options.setMapping(mapping.value().getValue());
31773195
}
3196+
FailureOr<scf::SCFTilingResult> maybeTilingResult =
3197+
scf::tileUsingSCF(rewriter, tileableOp, options);
31783198

31793199
if (failed(maybeTilingResult))
31803200
return transformOp.emitDefaultSilenceableFailure(tileableOp);
3181-
rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
3201+
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
31823202

31833203
tilingResult = *maybeTilingResult;
31843204
return DiagnosedSilenceableFailure::success();
@@ -3214,14 +3234,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
32143234
return status;
32153235

32163236
for (Operation *target : state.getPayloadOps(getTarget())) {
3217-
linalg::ForallTilingResult tilingResult;
3237+
scf::SCFTilingResult tilingResult;
32183238
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
32193239
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
32203240
getMapping(), tilingResult);
32213241
if (!diag.succeeded())
32223242
return diag;
3223-
tileOps.push_back(tilingResult.tileOp);
3224-
tiledOps.push_back(tilingResult.tiledOp);
3243+
tileOps.push_back(tilingResult.loops.front());
3244+
tiledOps.append(tilingResult.tiledOps);
32253245
}
32263246

32273247
transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3699,7 +3719,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
36993719

37003720
// OpBuilder only used to compute attributes.
37013721
OpBuilder b(getContext());
3702-
linalg::ForallTilingResult tilingResult;
3722+
scf::SCFTilingResult tilingResult;
37033723
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
37043724
/*rewriter=*/rewriter,
37053725
/*state=*/state,
@@ -3712,8 +3732,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
37123732
if (!diag.succeeded())
37133733
return diag;
37143734

3715-
results.push_back(tilingResult.tileOp);
3716-
results.push_back(tilingResult.tiledOp);
3735+
results.push_back(tilingResult.loops.front());
3736+
for (auto op : tilingResult.tiledOps)
3737+
results.push_back(op);
37173738
return DiagnosedSilenceableFailure::success();
37183739
}
37193740

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 0 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -435,188 +435,6 @@ static void calculateTileOffsetsAndSizes(
435435
}
436436
}
437437

438-
/// Returns a vector of bools representing if, for each axis, `op` can be tiled
439-
/// without incurring in a race condition and thus it is thread-safe to do the
440-
/// tiling. This is checked by iterating over numThreads and ensuring that the
441-
/// corresponding iterator type is "parallel". If it is not, then we know that
442-
/// such dimension is unsafe to tile.
443-
SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
444-
ArrayRef<OpFoldResult> numThreads) {
445-
auto iterators = linalgOp.getIteratorTypesArray();
446-
SmallVector<bool> safeToTile(numThreads.size(), true);
447-
448-
for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
449-
if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
450-
if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
451-
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
452-
}
453-
} else {
454-
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
455-
}
456-
}
457-
return safeToTile;
458-
}
459-
460-
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
461-
/// tiling is specified by the number of tiles/threads `numThreads` and the
462-
/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
463-
/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
464-
/// numThreads[i])`. If non-empty, the `mapping` is added as an
465-
/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
466-
/// that the dimension is not tiled, and can be thought of as tiling by the full
467-
/// size of data.
468-
/// It is the user's responsibility to ensure that `numThreads` is a valid
469-
/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
470-
/// Linalg case). If the dimension is not parallelizable, a warning is issued to
471-
/// notify the user that the generated code is not safe to parallelize. If
472-
/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
473-
/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
474-
static FailureOr<ForallTilingResult> tileToForallOpImpl(
475-
RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
476-
std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
477-
std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
478-
Location loc = op->getLoc();
479-
OpBuilder::InsertionGuard g(b);
480-
481-
SmallVector<Range> loopRanges = op.getIterationDomain(b);
482-
if (loopRanges.empty())
483-
return op->emitOpError("expected non-empty loop ranges");
484-
auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
485-
if (llvm::any_of(loopRanges, hasStrideOne))
486-
return op->emitOpError("only stride-1 supported atm");
487-
488-
// Gather destination tensors.
489-
SmallVector<Value> dest;
490-
if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
491-
return op->emitOpError("failed to get destination tensors");
492-
493-
SmallVector<OpFoldResult> nonZeroNumThreads =
494-
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
495-
return !isConstantIntValue(ofr, 0);
496-
}));
497-
SmallVector<Value> materializedNonZeroNumThreads =
498-
llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
499-
return getValueOrCreateConstantIndexOp(b, loc, ofr);
500-
}));
501-
502-
LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
503-
if (linalgOp) {
504-
// Check if tiling is thread safe and print a warning if not.
505-
SmallVector<bool> tilingSafety =
506-
safeToTileToForall(b.getContext(), linalgOp, numThreads);
507-
for (size_t i = 0; i < tilingSafety.size(); i++)
508-
if (!tilingSafety[i])
509-
op.emitWarning() << "tiling is not thread safe at axis #" << i;
510-
}
511-
512-
// 1. Create the ForallOp. We don't use the lambda body-builder
513-
// version because we require the use of RewriterBase in the body, so we
514-
// manually move the insertion point to the body below.
515-
scf::ForallOp forallOp = b.create<scf::ForallOp>(
516-
loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
517-
518-
// 2. Fill out the ForallOp body.
519-
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
520-
calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
521-
omitTileOffsetBoundsCheck, nominalTileSizes,
522-
tiledOffsets, tiledSizes);
523-
524-
// 3. Clone the tileable op and update its destination operands to use the
525-
// output bbArgs of the ForallOp.
526-
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
527-
Operation *tiledOp = nullptr;
528-
SmallVector<Value> tiledValues;
529-
{
530-
// 3.a. RAII guard, inserting within forallOp, before terminator.
531-
OpBuilder::InsertionGuard g(b);
532-
b.setInsertionPoint(forallOp.getTerminator());
533-
Operation *clonedOp = b.clone(*op.getOperation());
534-
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
535-
if (destinationStyleOp) {
536-
for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
537-
// Swap tensor inits with the corresponding block argument of the
538-
// scf.forall op. Memref inits remain as is.
539-
if (isa<TensorType>(outOperand.get().getType())) {
540-
auto *it = llvm::find(dest, outOperand.get());
541-
assert(it != dest.end() && "could not find destination tensor");
542-
unsigned destNum = std::distance(dest.begin(), it);
543-
outOperand.set(destBbArgs[destNum]);
544-
}
545-
}
546-
}
547-
548-
// 4. Tile the cloned op and delete the clone.
549-
FailureOr<TilingResult> tilingResult =
550-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
551-
tiledSizes);
552-
if (failed(tilingResult))
553-
return clonedOp->emitError("Failed to tile op: ");
554-
if (tilingResult->tiledOps.size() != 1) {
555-
return clonedOp->emitError("expected a single produced tiled op, got ")
556-
<< tilingResult->tiledOps.size();
557-
}
558-
559-
b.eraseOp(clonedOp);
560-
tiledOp = tilingResult->tiledOps.front();
561-
tiledValues = tilingResult->tiledValues;
562-
}
563-
564-
// 5. Parallel insert back into the result tensor.
565-
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
566-
tiledValues, destBbArgs)) {
567-
// 5.a. Partial subset information is inserted just before the terminator.
568-
OpBuilder::InsertionGuard g(b);
569-
b.setInsertionPoint(forallOp.getTerminator());
570-
571-
SmallVector<OpFoldResult> resultOffsets, resultSizes;
572-
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
573-
tiledSizes, resultOffsets,
574-
resultSizes)))
575-
return op->emitOpError("output offsets couldn't be calculated");
576-
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
577-
578-
// 5.b. Parallel insertions are inserted at the end of the combining
579-
// terminator.
580-
b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
581-
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
582-
std::get<2>(it), resultOffsets,
583-
resultSizes, strides);
584-
}
585-
return ForallTilingResult{forallOp, tiledOp};
586-
}
587-
588-
FailureOr<ForallTilingResult>
589-
linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
590-
ArrayRef<OpFoldResult> numThreads,
591-
std::optional<ArrayAttr> mapping) {
592-
return tileToForallOpImpl(b, op, numThreads,
593-
/*nominalTileSizes=*/std::nullopt, mapping,
594-
/*omitTileOffsetBoundsCheck=*/false);
595-
}
596-
597-
FailureOr<ForallTilingResult>
598-
linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
599-
ArrayRef<OpFoldResult> tileSizes,
600-
std::optional<ArrayAttr> mapping) {
601-
SmallVector<Range> loopRanges = op.getIterationDomain(b);
602-
unsigned nLoops = loopRanges.size();
603-
SmallVector<OpFoldResult> numThreads;
604-
numThreads.reserve(nLoops);
605-
AffineExpr s0, s1;
606-
bindSymbols(b.getContext(), s0, s1);
607-
AffineExpr divExpr = s0.ceilDiv(s1);
608-
for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
609-
OpFoldResult numTiles = std::get<0>(it);
610-
if (!isConstantIntValue(numTiles, 0))
611-
numTiles = makeComposedFoldedAffineApply(
612-
b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
613-
numThreads.push_back(numTiles);
614-
}
615-
return tileToForallOpImpl(b, op, numThreads,
616-
/*nominalTileSizes=*/tileSizes, mapping,
617-
/*omitTileOffsetBoundsCheck=*/true);
618-
}
619-
620438
template <typename LoopTy>
621439
static FailureOr<TiledLinalgOp>
622440
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,

0 commit comments

Comments
 (0)