Skip to content

Commit 38398d0

Browse files
Allow specifying both numThreads and tileSizes to keep the same existing semantics of distribution using number of threads.
1 parent 4151c40 commit 38398d0

File tree

7 files changed

+234
-206
lines changed

7 files changed

+234
-206
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ struct SCFTilingOptions {
3636
/// Returning a tile size of zero implies no tiling for that loop. If the
3737
/// size of the returned vector is smaller than the number of loops, the inner
3838
/// loops are not tiled. If the size of the returned vector is larger, then
39-
/// the vector is truncated to number of loops. Only one of
40-
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
41-
/// be used.
39+
/// the vector is truncated to number of loops.
4240
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
4341

4442
SCFTilingOptions &
@@ -51,23 +49,25 @@ struct SCFTilingOptions {
5149
/// proper interaction with folding.
5250
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
5351

54-
/// Computation function that returns the maximum number of tile to use for
55-
/// each loop. Returning a tile size of zero implies no tiling for that loop.
56-
/// If the size of the returned vector is smaller than the number of loops,
57-
/// the inner loops are not tiled. If the size of the returned vector is
58-
/// larger, then the vector is truncated to number of loops. Only one of
59-
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
60-
/// be used.
61-
SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
52+
/// Computation function that returns the number of threads to use for
53+
/// each loop. Returning a num threads of zero implies no tiling for that
54+
/// loop. If the size of the returned vector is smaller than the number of
55+
/// loops, the inner loops are not tiled. If the size of the returned vector
56+
/// is larger, then the vector is truncated to number of loops. Note: This
57+
/// option is only supported with loopType set to `LoopType::ForallOp`. If the
58+
/// tile size function is not specified while the num threads computation is,
59+
/// then the tile size is determined automatically to map at most one tile per
60+
/// thread.
61+
SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
6262

6363
SCFTilingOptions &
64-
setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
65-
maxNumTilesComputationFunction = std::move(fun);
64+
setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
65+
numThreadsComputationFunction = std::move(fun);
6666
return *this;
6767
}
6868
/// Convenience function to set the `tileSizeComputationFunction` to a
6969
/// function that computes tile sizes at the point they are needed.
70-
SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
70+
SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
7171

7272
/// The interchange vector to reorder the tiled loops.
7373
SmallVector<int64_t> interchangeVector = {};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2933,7 +2933,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29332933
scf::SCFTilingOptions options;
29342934
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
29352935
if (!mixedNumThreads.empty()) {
2936-
options.setMaxNumTiles(mixedNumThreads);
2936+
options.setNumThreads(mixedNumThreads);
29372937
} else {
29382938
SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
29392939
unsigned nLoops = loopRanges.size();
@@ -2951,7 +2951,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29512951
{loopRanges[i].size, numTiles});
29522952
numThreads.push_back(numTiles);
29532953
}
2954-
options.setMaxNumTiles(numThreads);
2954+
options.setNumThreads(numThreads);
2955+
options.setTileSizes(mixedTileSizes);
29552956
}
29562957
if (mapping) {
29572958
options.setMapping(mapping.value().getValue());

0 commit comments

Comments
 (0)