Skip to content

Commit 0b00056

Browse files
Add logic to account for negative tile sizes.
1 parent 58296c9 commit 0b00056

File tree

2 files changed

+102
-49
lines changed

2 files changed

+102
-49
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 75 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -219,45 +219,93 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
219219
b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
220220
}
221221

222+
/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
223+
/// than `iterationSize`.
224+
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
225+
OpFoldResult numThreads,
226+
OpFoldResult iterationSize) {
227+
std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
228+
std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
229+
std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
230+
if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
231+
return false;
232+
return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
233+
}
234+
222235
/// Compute the tile offsets and sizes.
223236
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
224237
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
225238
ArrayRef<Range> iterationDomain,
226-
ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
239+
ArrayRef<OpFoldResult> tileSizes,
240+
ArrayRef<OpFoldResult> numThreads) {
227241
SmallVector<OpFoldResult> offsets, sizes;
228242
int materializedLoopNum = 0;
229243

230-
AffineExpr d0, s0, s1, s2;
231-
AffineExpr offsetExpr;
232-
if (isLoopNormalized) {
233-
bindDims(rewriter.getContext(), d0);
244+
if (!numThreads.empty()) {
245+
AffineExpr d0, d1, s0, s1, s2;
246+
AffineExpr offsetExpr, residualTileSizeExpr;
247+
bindDims(rewriter.getContext(), d0, d1);
234248
bindSymbols(rewriter.getContext(), s0, s1, s2);
235-
offsetExpr = s0 + d0 * s1 * s2;
236-
}
249+
offsetExpr = d0 + d1 * s0 * s1;
250+
residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
237251

238-
for (auto [tileSize, loopRange] :
239-
llvm::zip_equal(tileSizes, iterationDomain)) {
240-
if (isConstantIntValue(tileSize, 0)) {
241-
offsets.push_back(loopRange.offset);
242-
sizes.push_back(loopRange.size);
243-
continue;
244-
}
245-
// If loop is normalized, the offset is (lb + iv * step * tileSize)
246-
Value iv = ivs[materializedLoopNum++];
247-
OpFoldResult offset;
248-
if (isLoopNormalized) {
249-
offset = affine::makeComposedFoldedAffineApply(
252+
for (auto [nt, tileSize, loopRange] :
253+
llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
254+
255+
if (isConstantIntValue(nt, 0) || isConstantIntValue(nt, 1)) {
256+
offsets.push_back(loopRange.offset);
257+
sizes.push_back(loopRange.size);
258+
continue;
259+
}
260+
261+
Value iv = ivs[materializedLoopNum++];
262+
OpFoldResult offset = affine::makeComposedFoldedAffineApply(
250263
rewriter, loc, offsetExpr,
251-
ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
264+
ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
252265
tileSize});
253-
} else {
254-
offset = getAsOpFoldResult(iv);
266+
OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
267+
rewriter, loc, residualTileSizeExpr,
268+
{loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
269+
OpFoldResult size = tileSize;
270+
if (!isConstantIntValue(residualTileSize, 0)) {
271+
OpFoldResult sizeMinusOffsetPerThread =
272+
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
273+
{offset, loopRange.size});
274+
size = affine::makeComposedFoldedAffineMin(
275+
rewriter, loc,
276+
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
277+
{sizeMinusOffsetPerThread, tileSize});
278+
}
279+
if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
280+
AffineMap maxMap =
281+
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
282+
size = affine::makeComposedFoldedAffineMax(
283+
rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
284+
}
285+
286+
offsets.push_back(offset);
287+
sizes.push_back(size);
288+
}
289+
return {offsets, sizes};
290+
} else {
291+
for (auto [tileSize, loopRange] :
292+
llvm::zip_equal(tileSizes, iterationDomain)) {
293+
294+
if (isConstantIntValue(tileSize, 0)) {
295+
offsets.push_back(loopRange.offset);
296+
sizes.push_back(loopRange.size);
297+
continue;
298+
}
299+
300+
Value iv = ivs[materializedLoopNum++];
301+
OpFoldResult offset = getAsOpFoldResult(iv);
302+
offsets.push_back(offset);
303+
OpFoldResult size =
304+
getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
305+
sizes.push_back(size);
255306
}
256-
offsets.push_back(offset);
257-
sizes.push_back(
258-
getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
307+
return {offsets, sizes};
259308
}
260-
return {offsets, sizes};
261309
}
262310

263311
/// Function to return the bounds of the loops to be generated.
@@ -764,7 +812,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
764812
// 4a. Compute the `offsets` and `sizes` to use for tiling.
765813
SmallVector<OpFoldResult> offsets, sizes;
766814
std::tie(offsets, sizes) = getTileOffsetAndSizes(
767-
rewriter, loc, ivs, iterationDomain, tileSizes, !numThreads.empty());
815+
rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
768816

769817
// 4b. If interchange was provided, apply inverse of the interchange
770818
// to get back the offsets/sizes in the order to be specified.

mlir/test/Dialect/Linalg/tile-to-forall.mlir

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
// Offset per thread:
44
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
55
// Per thread tile size.
6-
// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 10, -(d0 * (s0 ceildiv 10)) + s0)>
6+
// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
77
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
8-
// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 20, -(d0 * (s0 ceildiv 20)) + s0)>
8+
// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
99

1010
module {
1111
// CHECK-LABEL: matmul(
@@ -96,7 +96,7 @@ module {
9696
// In this test case, matmul dims and tile size are dynamic.
9797

9898
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
99-
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
99+
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
100100
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
101101

102102
// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
@@ -140,7 +140,7 @@ module attributes {transform.with_named_sequence} {
140140

141141
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
142142

143-
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (15, d0 * -15 + 300)>
143+
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
144144
// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
145145
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
146146
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
@@ -176,6 +176,7 @@ module attributes {transform.with_named_sequence} {
176176
transform.yield
177177
}
178178
}
179+
179180
// -----
180181

181182
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
@@ -296,7 +297,7 @@ module {
296297

297298
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
298299
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
299-
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
300+
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
300301
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
301302
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
302303
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
@@ -339,7 +340,6 @@ module attributes {transform.with_named_sequence} {
339340
// -----
340341

341342
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
342-
// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
343343
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
344344
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
345345

@@ -352,8 +352,7 @@ module attributes {transform.with_named_sequence} {
352352
%OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
353353
-> (tensor<100xf32>, tensor<100xf32>) {
354354
// CHECK: scf.forall (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
355-
// CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
356-
// CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
355+
// CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV0]])
357356
// CHECK-NOT: affine.min
358357
// CHECK-NOT: affine.max
359358
// CHECK: %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
@@ -453,9 +452,10 @@ module attributes {transform.with_named_sequence} {
453452
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
454453
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
455454
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
456-
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
457-
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
458-
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
455+
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
456+
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
457+
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
458+
// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
459459

460460
// CHECK-LABEL: matmul_tile_size_dynamic(
461461
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -470,10 +470,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
470470
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
471471
// CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
472472
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
473-
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
474-
// CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
475-
// CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
476-
// CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
473+
// CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
474+
// CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
475+
// CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
476+
// CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
477+
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
478+
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
477479
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
478480
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
479481
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
@@ -521,9 +523,10 @@ module attributes {transform.with_named_sequence} {
521523
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
522524
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
523525
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
524-
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
525-
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
526-
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
526+
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
527+
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
528+
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
529+
// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
527530

528531
// CHECK-LABEL: matmul_tile_size_dynamic(
529532
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -538,10 +541,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
538541
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
539542
// CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
540543
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
541-
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
542-
// CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
543-
// CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
544-
// CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
544+
// CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
545+
// CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
546+
// CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
547+
// CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
548+
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
549+
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
545550
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
546551
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
547552
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :

0 commit comments

Comments
 (0)