Skip to content

Commit 4151c40

Browse files
[mlir][SCF] Allow tiling by specifying maximum number of tiles.
1 parent 4ab37e4 commit 4151c40

File tree

8 files changed

+270
-302
lines changed

8 files changed

+270
-302
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class GenericOp;
3030
class LinalgOp;
3131
} // namespace linalg
3232

33+
namespace scf {
34+
struct SCFTilingResult;
35+
} // namespace scf
36+
3337
namespace tensor {
3438
class InsertSliceOp;
3539
class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
6064
ArrayRef<OpFoldResult> mixedNumThreads,
6165
ArrayRef<OpFoldResult> mixedTileSizes,
6266
std::optional<ArrayAttr> mapping,
63-
linalg::ForallTilingResult &tilingResult);
67+
scf::SCFTilingResult &tilingResult);
6468

6569
} // namespace transform
6670
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
846846
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
847847
int64_t divisor);
848848

849-
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
850-
/// tiling by `numThreads`.
851-
/// If non-empty, the `mapping` is added as an attribute to the
852-
/// resulting `scf.forall`.
853-
/// Zero tile sizes indicate that the dimension is not tiled, and can be
854-
/// thought of as tiling by the full size of data. It is the user's
855-
/// responsibility to ensure that `numThreads` is a valid tiling specification
856-
/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
857-
struct ForallTilingResult {
858-
Operation *tileOp;
859-
Operation *tiledOp;
860-
};
861-
FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
862-
TilingInterface op,
863-
ArrayRef<OpFoldResult> numThreads,
864-
std::optional<ArrayAttr> mapping);
865-
866-
/// Same as `tileToForallOp`, but calculate the number of threads
867-
/// required using the given tileSizes.
868-
FailureOr<ForallTilingResult>
869-
tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
870-
ArrayRef<OpFoldResult> tileSizes,
871-
std::optional<ArrayAttr> mapping);
872-
873849
/// Transformation information returned after reduction tiling.
874850
struct ForallReductionTilingResult {
875851
/// The partial reduction tiled op generated.

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,13 @@ using SCFTileSizeComputationFunction =
3232

3333
/// Options to use to control tiling.
3434
struct SCFTilingOptions {
35-
/// Computation function that returns the tile sizes for each operation.
36-
/// Delayed construction of constant tile sizes should occur to interoperate
37-
/// with folding.
35+
/// Computation function that returns the tile sizes to use for each loop.
36+
/// Returning a tile size of zero implies no tiling for that loop. If the
37+
/// size of the returned vector is smaller than the number of loops, the inner
38+
/// loops are not tiled. If the size of the returned vector is larger, then
39+
/// the vector is truncated to number of loops. Only one of
40+
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
41+
/// be used.
3842
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
3943

4044
SCFTilingOptions &
@@ -45,7 +49,25 @@ struct SCFTilingOptions {
4549
/// Convenience function to set the `tileSizeComputationFunction` to a
4650
/// function that computes tile sizes at the point they are needed. Allows
4751
/// proper interaction with folding.
48-
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
52+
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
53+
54+
/// Computation function that returns the maximum number of tile to use for
55+
/// each loop. Returning a tile size of zero implies no tiling for that loop.
56+
/// If the size of the returned vector is smaller than the number of loops,
57+
/// the inner loops are not tiled. If the size of the returned vector is
58+
/// larger, then the vector is truncated to number of loops. Only one of
59+
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
60+
/// be used.
61+
SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
62+
63+
SCFTilingOptions &
64+
setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
65+
maxNumTilesComputationFunction = std::move(fun);
66+
return *this;
67+
}
68+
/// Convenience function to set the `tileSizeComputationFunction` to a
69+
/// function that computes tile sizes at the point they are needed.
70+
SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
4971

5072
/// The interchange vector to reorder the tiled loops.
5173
SmallVector<int64_t> interchangeVector = {};
@@ -67,9 +89,8 @@ struct SCFTilingOptions {
6789
/// when using loop constructs that dont support such a mapping (like
6890
/// `scf.for`)
6991
SmallVector<Attribute> mappingVector = {};
70-
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
71-
mappingVector = llvm::map_to_vector(
72-
mapping, [](auto attr) -> Attribute { return attr; });
92+
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
93+
mappingVector = llvm::to_vector(mapping);
7394
return *this;
7495
}
7596
};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2919,7 +2919,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29192919
TransformOpInterface transformOp, Operation *target,
29202920
ArrayRef<OpFoldResult> mixedNumThreads,
29212921
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
2922-
linalg::ForallTilingResult &tilingResult) {
2922+
scf::SCFTilingResult &tilingResult) {
29232923
// Transform all targets one by one.
29242924
auto tileableOp = dyn_cast<TilingInterface>(target);
29252925
if (!tileableOp) {
@@ -2930,18 +2930,38 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29302930
return diag;
29312931
}
29322932
rewriter.setInsertionPoint(tileableOp);
2933-
FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
2933+
scf::SCFTilingOptions options;
2934+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
29342935
if (!mixedNumThreads.empty()) {
2935-
maybeTilingResult =
2936-
linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
2936+
options.setMaxNumTiles(mixedNumThreads);
29372937
} else {
2938-
maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
2939-
rewriter, tileableOp, mixedTileSizes, mapping);
2938+
SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
2939+
unsigned nLoops = loopRanges.size();
2940+
SmallVector<OpFoldResult> numThreads;
2941+
numThreads.reserve(nLoops);
2942+
AffineExpr s0, s1;
2943+
bindSymbols(rewriter.getContext(), s0, s1);
2944+
AffineExpr divExpr = s0.ceilDiv(s1);
2945+
for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
2946+
i < e; ++i) {
2947+
OpFoldResult numTiles = mixedTileSizes[i];
2948+
if (!isConstantIntValue(numTiles, 0))
2949+
numTiles = affine::makeComposedFoldedAffineApply(
2950+
rewriter, tileableOp.getLoc(), divExpr,
2951+
{loopRanges[i].size, numTiles});
2952+
numThreads.push_back(numTiles);
2953+
}
2954+
options.setMaxNumTiles(numThreads);
2955+
}
2956+
if (mapping) {
2957+
options.setMapping(mapping.value().getValue());
29402958
}
2959+
FailureOr<scf::SCFTilingResult> maybeTilingResult =
2960+
scf::tileUsingSCF(rewriter, tileableOp, options);
29412961

29422962
if (failed(maybeTilingResult))
29432963
return transformOp.emitDefaultSilenceableFailure(tileableOp);
2944-
rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
2964+
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
29452965

29462966
tilingResult = *maybeTilingResult;
29472967
return DiagnosedSilenceableFailure::success();
@@ -2977,14 +2997,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
29772997
return status;
29782998

29792999
for (Operation *target : state.getPayloadOps(getTarget())) {
2980-
linalg::ForallTilingResult tilingResult;
3000+
scf::SCFTilingResult tilingResult;
29813001
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
29823002
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
29833003
getMapping(), tilingResult);
29843004
if (!diag.succeeded())
29853005
return diag;
2986-
tileOps.push_back(tilingResult.tileOp);
2987-
tiledOps.push_back(tilingResult.tiledOp);
3006+
tileOps.push_back(tilingResult.loops.front());
3007+
tiledOps.append(tilingResult.tiledOps);
29883008
}
29893009

29903010
transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3462,7 +3482,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
34623482

34633483
// OpBuilder only used to compute attributes.
34643484
OpBuilder b(getContext());
3465-
linalg::ForallTilingResult tilingResult;
3485+
scf::SCFTilingResult tilingResult;
34663486
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
34673487
/*rewriter=*/rewriter,
34683488
/*state=*/state,
@@ -3475,8 +3495,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
34753495
if (!diag.succeeded())
34763496
return diag;
34773497

3478-
results.push_back(tilingResult.tileOp);
3479-
results.push_back(tilingResult.tiledOp);
3498+
results.push_back(tilingResult.loops.front());
3499+
for (auto op : tilingResult.tiledOps)
3500+
results.push_back(op);
34803501
return DiagnosedSilenceableFailure::success();
34813502
}
34823503

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 0 additions & 182 deletions
Original file line numberDiff line numberDiff line change
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
304304
}
305305
}
306306

307-
/// Returns a vector of bools representing if, for each axis, `op` can be tiled
308-
/// without incurring in a race condition and thus it is thread-safe to do the
309-
/// tiling. This is checked by iterating over numThreads and ensuring that the
310-
/// corresponding iterator type is "parallel". If it is not, then we know that
311-
/// such dimension is unsafe to tile.
312-
SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
313-
ArrayRef<OpFoldResult> numThreads) {
314-
auto iterators = linalgOp.getIteratorTypesArray();
315-
SmallVector<bool> safeToTile(numThreads.size(), true);
316-
317-
for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
318-
if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
319-
if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
320-
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
321-
}
322-
} else {
323-
safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
324-
}
325-
}
326-
return safeToTile;
327-
}
328-
329-
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
330-
/// tiling is specified by the number of tiles/threads `numThreads` and the
331-
/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
332-
/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
333-
/// numThreads[i])`. If non-empty, the `mapping` is added as an
334-
/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
335-
/// that the dimension is not tiled, and can be thought of as tiling by the full
336-
/// size of data.
337-
/// It is the user's responsibility to ensure that `numThreads` is a valid
338-
/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
339-
/// Linalg case). If the dimension is not parallelizable, a warning is issued to
340-
/// notify the user that the generated code is not safe to parallelize. If
341-
/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
342-
/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
343-
static FailureOr<ForallTilingResult> tileToForallOpImpl(
344-
RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
345-
std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
346-
std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
347-
Location loc = op->getLoc();
348-
OpBuilder::InsertionGuard g(b);
349-
350-
SmallVector<Range> loopRanges = op.getIterationDomain(b);
351-
if (loopRanges.empty())
352-
return op->emitOpError("expected non-empty loop ranges");
353-
auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
354-
if (llvm::any_of(loopRanges, hasStrideOne))
355-
return op->emitOpError("only stride-1 supported atm");
356-
357-
// Gather destination tensors.
358-
SmallVector<Value> dest;
359-
if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
360-
return op->emitOpError("failed to get destination tensors");
361-
362-
SmallVector<OpFoldResult> nonZeroNumThreads =
363-
llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
364-
return !isConstantIntValue(ofr, 0);
365-
}));
366-
SmallVector<Value> materializedNonZeroNumThreads =
367-
llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
368-
return getValueOrCreateConstantIndexOp(b, loc, ofr);
369-
}));
370-
371-
LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
372-
if (linalgOp) {
373-
// Check if tiling is thread safe and print a warning if not.
374-
SmallVector<bool> tilingSafety =
375-
safeToTileToForall(b.getContext(), linalgOp, numThreads);
376-
for (size_t i = 0; i < tilingSafety.size(); i++)
377-
if (!tilingSafety[i])
378-
op.emitWarning() << "tiling is not thread safe at axis #" << i;
379-
}
380-
381-
// 1. Create the ForallOp. We don't use the lambda body-builder
382-
// version because we require the use of RewriterBase in the body, so we
383-
// manually move the insertion point to the body below.
384-
scf::ForallOp forallOp = b.create<scf::ForallOp>(
385-
loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
386-
387-
// 2. Fill out the ForallOp body.
388-
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
389-
calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
390-
omitTileOffsetBoundsCheck, nominalTileSizes,
391-
tiledOffsets, tiledSizes);
392-
393-
// 3. Clone the tileable op and update its destination operands to use the
394-
// output bbArgs of the ForallOp.
395-
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
396-
Operation *tiledOp = nullptr;
397-
SmallVector<Value> tiledValues;
398-
{
399-
// 3.a. RAII guard, inserting within forallOp, before terminator.
400-
OpBuilder::InsertionGuard g(b);
401-
b.setInsertionPoint(forallOp.getTerminator());
402-
Operation *clonedOp = b.clone(*op.getOperation());
403-
auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
404-
if (destinationStyleOp) {
405-
for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
406-
// Swap tensor inits with the corresponding block argument of the
407-
// scf.forall op. Memref inits remain as is.
408-
if (isa<TensorType>(outOperand.get().getType())) {
409-
auto *it = llvm::find(dest, outOperand.get());
410-
assert(it != dest.end() && "could not find destination tensor");
411-
unsigned destNum = std::distance(dest.begin(), it);
412-
outOperand.set(destBbArgs[destNum]);
413-
}
414-
}
415-
}
416-
417-
// 4. Tile the cloned op and delete the clone.
418-
FailureOr<TilingResult> tilingResult =
419-
cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
420-
tiledSizes);
421-
if (failed(tilingResult))
422-
return clonedOp->emitError("Failed to tile op: ");
423-
if (tilingResult->tiledOps.size() != 1) {
424-
return clonedOp->emitError("expected a single produced tiled op, got ")
425-
<< tilingResult->tiledOps.size();
426-
}
427-
428-
b.eraseOp(clonedOp);
429-
tiledOp = tilingResult->tiledOps.front();
430-
tiledValues = tilingResult->tiledValues;
431-
}
432-
433-
// 5. Parallel insert back into the result tensor.
434-
for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
435-
tiledValues, destBbArgs)) {
436-
// 5.a. Partial subset information is inserted just before the terminator.
437-
OpBuilder::InsertionGuard g(b);
438-
b.setInsertionPoint(forallOp.getTerminator());
439-
440-
SmallVector<OpFoldResult> resultOffsets, resultSizes;
441-
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
442-
tiledSizes, resultOffsets,
443-
resultSizes)))
444-
return op->emitOpError("output offsets couldn't be calculated");
445-
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
446-
447-
// 5.b. Parallel insertions are inserted at the end of the combining
448-
// terminator.
449-
b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
450-
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
451-
std::get<2>(it), resultOffsets,
452-
resultSizes, strides);
453-
}
454-
return ForallTilingResult{forallOp, tiledOp};
455-
}
456-
457-
FailureOr<ForallTilingResult>
458-
linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
459-
ArrayRef<OpFoldResult> numThreads,
460-
std::optional<ArrayAttr> mapping) {
461-
return tileToForallOpImpl(b, op, numThreads,
462-
/*nominalTileSizes=*/std::nullopt, mapping,
463-
/*omitTileOffsetBoundsCheck=*/false);
464-
}
465-
466-
FailureOr<ForallTilingResult>
467-
linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
468-
ArrayRef<OpFoldResult> tileSizes,
469-
std::optional<ArrayAttr> mapping) {
470-
SmallVector<Range> loopRanges = op.getIterationDomain(b);
471-
unsigned nLoops = loopRanges.size();
472-
SmallVector<OpFoldResult> numThreads;
473-
numThreads.reserve(nLoops);
474-
AffineExpr s0, s1;
475-
bindSymbols(b.getContext(), s0, s1);
476-
AffineExpr divExpr = s0.ceilDiv(s1);
477-
for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
478-
OpFoldResult numTiles = std::get<0>(it);
479-
if (!isConstantIntValue(numTiles, 0))
480-
numTiles = makeComposedFoldedAffineApply(
481-
b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
482-
numThreads.push_back(numTiles);
483-
}
484-
return tileToForallOpImpl(b, op, numThreads,
485-
/*nominalTileSizes=*/tileSizes, mapping,
486-
/*omitTileOffsetBoundsCheck=*/true);
487-
}
488-
489307
template <typename LoopTy>
490308
static FailureOr<TiledLinalgOp>
491309
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,

0 commit comments

Comments
 (0)