Skip to content

Commit c2c13ce

Browse files
Add logic to account for negative tile sizes.
1 parent 38398d0 commit c2c13ce

File tree

2 files changed

+102
-49
lines changed

2 files changed

+102
-49
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -220,45 +220,93 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
220220
b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
221221
}
222222

223+
/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
224+
/// than `iterationSize`.
225+
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
226+
OpFoldResult numThreads,
227+
OpFoldResult iterationSize) {
228+
std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
229+
std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
230+
std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
231+
if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
232+
return false;
233+
return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
234+
}
235+
223236
/// Compute the tile offsets and sizes.
224237
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
225238
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
226239
ArrayRef<Range> iterationDomain,
227-
ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
240+
ArrayRef<OpFoldResult> tileSizes,
241+
ArrayRef<OpFoldResult> numThreads) {
228242
SmallVector<OpFoldResult> offsets, sizes;
229243
int materializedLoopNum = 0;
230244

231-
AffineExpr d0, s0, s1, s2;
232-
AffineExpr offsetExpr;
233-
if (isLoopNormalized) {
234-
bindDims(rewriter.getContext(), d0);
245+
if (!numThreads.empty()) {
246+
AffineExpr d0, d1, s0, s1, s2;
247+
AffineExpr offsetExpr, residualTileSizeExpr;
248+
bindDims(rewriter.getContext(), d0, d1);
235249
bindSymbols(rewriter.getContext(), s0, s1, s2);
236-
offsetExpr = s0 + d0 * s1 * s2;
237-
}
250+
offsetExpr = d0 + d1 * s0 * s1;
251+
residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
238252

239-
for (auto [tileSize, loopRange] :
240-
llvm::zip_equal(tileSizes, iterationDomain)) {
241-
if (isConstantIntValue(tileSize, 0)) {
242-
offsets.push_back(loopRange.offset);
243-
sizes.push_back(loopRange.size);
244-
continue;
245-
}
246-
// If loop is normalized, the offset is (lb + iv * step * tileSize)
247-
Value iv = ivs[materializedLoopNum++];
248-
OpFoldResult offset;
249-
if (isLoopNormalized) {
250-
offset = affine::makeComposedFoldedAffineApply(
253+
for (auto [nt, tileSize, loopRange] :
254+
llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
255+
256+
if (isConstantIntValue(nt, 0) || isConstantIntValue(nt, 1)) {
257+
offsets.push_back(loopRange.offset);
258+
sizes.push_back(loopRange.size);
259+
continue;
260+
}
261+
262+
Value iv = ivs[materializedLoopNum++];
263+
OpFoldResult offset = affine::makeComposedFoldedAffineApply(
251264
rewriter, loc, offsetExpr,
252-
ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
265+
ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
253266
tileSize});
254-
} else {
255-
offset = getAsOpFoldResult(iv);
267+
OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
268+
rewriter, loc, residualTileSizeExpr,
269+
{loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
270+
OpFoldResult size = tileSize;
271+
if (!isConstantIntValue(residualTileSize, 0)) {
272+
OpFoldResult sizeMinusOffsetPerThread =
273+
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
274+
{offset, loopRange.size});
275+
size = affine::makeComposedFoldedAffineMin(
276+
rewriter, loc,
277+
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
278+
{sizeMinusOffsetPerThread, tileSize});
279+
}
280+
if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
281+
AffineMap maxMap =
282+
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
283+
size = affine::makeComposedFoldedAffineMax(
284+
rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
285+
}
286+
287+
offsets.push_back(offset);
288+
sizes.push_back(size);
289+
}
290+
return {offsets, sizes};
291+
} else {
292+
for (auto [tileSize, loopRange] :
293+
llvm::zip_equal(tileSizes, iterationDomain)) {
294+
295+
if (isConstantIntValue(tileSize, 0)) {
296+
offsets.push_back(loopRange.offset);
297+
sizes.push_back(loopRange.size);
298+
continue;
299+
}
300+
301+
Value iv = ivs[materializedLoopNum++];
302+
OpFoldResult offset = getAsOpFoldResult(iv);
303+
offsets.push_back(offset);
304+
OpFoldResult size =
305+
getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
306+
sizes.push_back(size);
256307
}
257-
offsets.push_back(offset);
258-
sizes.push_back(
259-
getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
308+
return {offsets, sizes};
260309
}
261-
return {offsets, sizes};
262310
}
263311

264312
/// Function to return the bounds of the loops to be generated.
@@ -765,7 +813,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
765813
// 4a. Compute the `offsets` and `sizes` to use for tiling.
766814
SmallVector<OpFoldResult> offsets, sizes;
767815
std::tie(offsets, sizes) = getTileOffsetAndSizes(
768-
rewriter, loc, ivs, iterationDomain, tileSizes, !numThreads.empty());
816+
rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
769817

770818
// 4b. If interchange was provided, apply inverse of the interchange
771819
// to get back the offsets/sizes in the order to be specified.

mlir/test/Dialect/Linalg/tile-to-forall.mlir

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
// Offset per thread:
44
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
55
// Per thread tile size.
6-
// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 10, -(d0 * (s0 ceildiv 10)) + s0)>
6+
// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
77
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
8-
// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 20, -(d0 * (s0 ceildiv 20)) + s0)>
8+
// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
99

1010
module {
1111
// CHECK-LABEL: matmul(
@@ -96,7 +96,7 @@ module {
9696
// In this test case, matmul dims and tile size are dynamic.
9797

9898
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
99-
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
99+
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
100100
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
101101

102102
// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
@@ -140,7 +140,7 @@ module attributes {transform.with_named_sequence} {
140140

141141
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
142142

143-
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (15, d0 * -15 + 300)>
143+
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
144144
// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
145145
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
146146
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
@@ -176,6 +176,7 @@ module attributes {transform.with_named_sequence} {
176176
transform.yield
177177
}
178178
}
179+
179180
// -----
180181

181182
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
@@ -296,7 +297,7 @@ module {
296297

297298
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
298299
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
299-
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
300+
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
300301
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
301302
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
302303
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
@@ -339,7 +340,6 @@ module attributes {transform.with_named_sequence} {
339340
// -----
340341

341342
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
342-
// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
343343
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
344344
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
345345

@@ -352,8 +352,7 @@ module attributes {transform.with_named_sequence} {
352352
%OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
353353
-> (tensor<100xf32>, tensor<100xf32>) {
354354
// CHECK: scf.forall (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
355-
// CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
356-
// CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
355+
// CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV0]])
357356
// CHECK-NOT: affine.min
358357
// CHECK-NOT: affine.max
359358
// CHECK: %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
@@ -453,9 +452,10 @@ module attributes {transform.with_named_sequence} {
453452
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
454453
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
455454
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
456-
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
457-
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
458-
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
455+
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
456+
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
457+
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
458+
// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
459459

460460
// CHECK-LABEL: matmul_tile_size_dynamic(
461461
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -470,10 +470,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
470470
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
471471
// CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
472472
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
473-
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
474-
// CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
475-
// CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
476-
// CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
473+
// CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
474+
// CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
475+
// CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
476+
// CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
477+
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
478+
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
477479
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
478480
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
479481
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
@@ -521,9 +523,10 @@ module attributes {transform.with_named_sequence} {
521523
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
522524
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
523525
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
524-
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
525-
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
526-
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
526+
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
527+
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
528+
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
529+
// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
527530

528531
// CHECK-LABEL: matmul_tile_size_dynamic(
529532
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -538,10 +541,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
538541
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
539542
// CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
540543
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
541-
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
542-
// CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
543-
// CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
544-
// CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
544+
// CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
545+
// CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
546+
// CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
547+
// CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
548+
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
549+
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
545550
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
546551
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
547552
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :

0 commit comments

Comments
 (0)