Skip to content

Commit 0bb323b

Browse files
Allow specifying both numThreads and tileSizes to keep the same existing semantics of distribution using number of threads.
1 parent ddca736 commit 0bb323b

File tree

7 files changed

+234
-206
lines changed

7 files changed

+234
-206
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ struct SCFTilingOptions {
3535
/// Returning a tile size of zero implies no tiling for that loop. If the
3636
/// size of the returned vector is smaller than the number of loops, the inner
3737
/// loops are not tiled. If the size of the returned vector is larger, then
38-
/// the vector is truncated to number of loops. Only one of
39-
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
40-
/// be used.
38+
/// the vector is truncated to number of loops.
4139
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
4240

4341
SCFTilingOptions &
@@ -50,23 +48,25 @@ struct SCFTilingOptions {
5048
/// proper interaction with folding.
5149
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
5250

53-
/// Computation function that returns the maximum number of tile to use for
54-
/// each loop. Returning a tile size of zero implies no tiling for that loop.
55-
/// If the size of the returned vector is smaller than the number of loops,
56-
/// the inner loops are not tiled. If the size of the returned vector is
57-
/// larger, then the vector is truncated to number of loops. Only one of
58-
/// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
59-
/// be used.
60-
SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
51+
/// Computation function that returns the number of threads to use for
52+
/// each loop. Returning a thread count of zero implies no tiling for that
53+
/// loop. If the size of the returned vector is smaller than the number of
54+
/// loops, the inner loops are not tiled. If the size of the returned vector
55+
/// is larger, then the vector is truncated to number of loops. Note: This
56+
/// option is only supported with loopType set to `LoopType::ForallOp`. If the
57+
/// tile size function is not specified while the num threads computation is,
58+
/// then the tile size is determined automatically to map at most one tile per
59+
/// thread.
60+
SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
6161

6262
SCFTilingOptions &
63-
setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
64-
maxNumTilesComputationFunction = std::move(fun);
63+
setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
64+
numThreadsComputationFunction = std::move(fun);
6565
return *this;
6666
}
6767
/// Convenience function to set the `numThreadsComputationFunction` to a
6868
/// function that computes the number of threads at the point they are needed.
69-
SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
69+
SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
7070

7171
/// The interchange vector to reorder the tiled loops.
7272
SmallVector<int64_t> interchangeVector = {};

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2931,7 +2931,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29312931
scf::SCFTilingOptions options;
29322932
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
29332933
if (!mixedNumThreads.empty()) {
2934-
options.setMaxNumTiles(mixedNumThreads);
2934+
options.setNumThreads(mixedNumThreads);
29352935
} else {
29362936
SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
29372937
unsigned nLoops = loopRanges.size();
@@ -2949,7 +2949,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
29492949
{loopRanges[i].size, numTiles});
29502950
numThreads.push_back(numTiles);
29512951
}
2952-
options.setMaxNumTiles(numThreads);
2952+
options.setNumThreads(numThreads);
2953+
options.setTileSizes(mixedTileSizes);
29532954
}
29542955
if (mapping) {
29552956
options.setMapping(mapping.value().getValue());

0 commit comments

Comments
 (0)