@@ -75,11 +75,11 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
75
75
static LogicalResult
76
76
verifyTileSizeOptions (RewriterBase &rewriter, Location loc,
77
77
const scf::SCFTilingOptions &options) {
78
- // Specifying number of tile is only supported on `scf.forall` op.
78
+ // Specifying number of threads is only supported on `scf.forall` op.
79
79
if (options.numThreadsComputationFunction &&
80
80
options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
81
81
return rewriter.notifyMatchFailure (
82
- loc, " number of tiles/ threads can only by specified when loop type is "
82
+ loc, " number of threads can only by specified when loop type is "
83
83
" set to use `scf.forall`" );
84
84
}
85
85
@@ -111,25 +111,27 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
111
111
// If the tile sizes are also specified, use those.
112
112
if (options.tileSizeComputationFunction ) {
113
113
tileSizes = options.tileSizeComputationFunction (rewriter, op);
114
- } else {
115
- // Compute the tile sizes from the iteration domain and number
116
- // of tiles as follows
117
- // - niters = ceilDiv(ub - lb, step)
118
- // - tileSize = ceilDiv(niters, numThreads)
119
- AffineExpr s0, s1, s2, s3;
120
- bindSymbols (rewriter.getContext (), s0, s1, s2, s3);
121
- AffineExpr numItersExpr = (s1 - s0).ceilDiv (s2);
122
- AffineExpr tileSizeExpr = numItersExpr.ceilDiv (s3);
123
114
tileSizes.resize (numLoops, zero);
124
- for (auto [index, range, nt] :
125
- llvm::enumerate (iterationDomain, numThreads)) {
126
- if (isConstantIntValue (nt, 0 ))
127
- continue ;
115
+ return {tileSizes, numThreads};
116
+ }
128
117
129
- tileSizes[index] = affine::makeComposedFoldedAffineApply (
130
- rewriter, op.getLoc (), tileSizeExpr,
131
- {range.offset , range.size , range.stride , nt});
132
- }
118
+ // Compute the tile sizes from the iteration domain and number
119
+ // of threads as follows
120
+ // - niters = ceilDiv(ub - lb, step)
121
+ // - tileSize = ceilDiv(niters, numThreads)
122
+ AffineExpr s0, s1, s2, s3;
123
+ bindSymbols (rewriter.getContext (), s0, s1, s2, s3);
124
+ AffineExpr numItersExpr = (s1 - s0).ceilDiv (s2);
125
+ AffineExpr tileSizeExpr = numItersExpr.ceilDiv (s3);
126
+ tileSizes.resize (numLoops, zero);
127
+ for (auto [index, range, nt] :
128
+ llvm::enumerate (iterationDomain, numThreads)) {
129
+ if (isConstantIntValue (nt, 0 ))
130
+ continue ;
131
+
132
+ tileSizes[index] = affine::makeComposedFoldedAffineApply (
133
+ rewriter, op.getLoc (), tileSizeExpr,
134
+ {range.offset , range.size , range.stride , nt});
133
135
}
134
136
tileSizes.resize (numLoops, zero);
135
137
return {tileSizes, numThreads};
@@ -139,9 +141,9 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
139
141
// skips tiling a particular dimension. This convention is significantly
140
142
// simpler to handle instead of adjusting affine maps to account for missing
141
143
// dimensions.
142
- if (options.tileSizeComputationFunction ) {
143
- tileSizes = options. tileSizeComputationFunction (rewriter, op );
144
- }
144
+ assert (options.tileSizeComputationFunction &&
145
+ " expected tile sizes to be specified " );
146
+ tileSizes = options. tileSizeComputationFunction (rewriter, op);
145
147
tileSizes.resize (numLoops, zero);
146
148
147
149
return {tileSizes, numThreads};
0 commit comments