Skip to content

Commit c828576

Browse files
Next round of comments.
1 parent 1857e5c commit c828576

File tree

1 file changed

+24
-22
lines changed

1 file changed

+24
-22
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -75,11 +75,11 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
7575
static LogicalResult
7676
verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
7777
const scf::SCFTilingOptions &options) {
78-
// Specifying number of tile is only supported on `scf.forall` op.
78+
// Specifying number of threads is only supported on `scf.forall` op.
7979
if (options.numThreadsComputationFunction &&
8080
options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
8181
return rewriter.notifyMatchFailure(
82-
loc, "number of tiles/threads can only by specified when loop type is "
82+
loc, "number of threads can only by specified when loop type is "
8383
"set to use `scf.forall`");
8484
}
8585

@@ -111,25 +111,27 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
111111
// If the number of tiles is also specified, use that.
112112
if (options.tileSizeComputationFunction) {
113113
tileSizes = options.tileSizeComputationFunction(rewriter, op);
114-
} else {
115-
// Compute the tile sizes from the iteration domain and number
116-
// of tiles as follows
117-
// - niters = ceilDiv(ub - lb, step)
118-
// - tileSize = ceilDiv(niters, numThreads)
119-
AffineExpr s0, s1, s2, s3;
120-
bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
121-
AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
122-
AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
123114
tileSizes.resize(numLoops, zero);
124-
for (auto [index, range, nt] :
125-
llvm::enumerate(iterationDomain, numThreads)) {
126-
if (isConstantIntValue(nt, 0))
127-
continue;
115+
return {tileSizes, numThreads};
116+
}
128117

129-
tileSizes[index] = affine::makeComposedFoldedAffineApply(
130-
rewriter, op.getLoc(), tileSizeExpr,
131-
{range.offset, range.size, range.stride, nt});
132-
}
118+
// Compute the tile sizes from the iteration domain and number
119+
// of tiles as follows
120+
// - niters = ceilDiv(ub - lb, step)
121+
// - tileSize = ceilDiv(niters, numThreads)
122+
AffineExpr s0, s1, s2, s3;
123+
bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
124+
AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
125+
AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
126+
tileSizes.resize(numLoops, zero);
127+
for (auto [index, range, nt] :
128+
llvm::enumerate(iterationDomain, numThreads)) {
129+
if (isConstantIntValue(nt, 0))
130+
continue;
131+
132+
tileSizes[index] = affine::makeComposedFoldedAffineApply(
133+
rewriter, op.getLoc(), tileSizeExpr,
134+
{range.offset, range.size, range.stride, nt});
133135
}
134136
tileSizes.resize(numLoops, zero);
135137
return {tileSizes, numThreads};
@@ -139,9 +141,9 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
139141
// skips tiling a particular dimension. This convention is significantly
140142
// simpler to handle instead of adjusting affine maps to account for missing
141143
// dimensions.
142-
if (options.tileSizeComputationFunction) {
143-
tileSizes = options.tileSizeComputationFunction(rewriter, op);
144-
}
144+
assert(options.tileSizeComputationFunction &&
145+
"expected tile sizes to be specified");
146+
tileSizes = options.tileSizeComputationFunction(rewriter, op);
145147
tileSizes.resize(numLoops, zero);
146148

147149
return {tileSizes, numThreads};

0 commit comments

Comments
 (0)