Commit 479c8d6

Use full slices when tiling by the full loop trip count (to support non-monotonic expressions) (#468)
When tiling a chain of linalg ops, only the first op's tile sizes can be set to 0 to mark dimensions as untiled; its producers instead receive a tile size equal to the loop trip count. We must return the full slice in those cases because the code that computes slice sizes in the general case does not handle non-monotonic affine expressions. Otherwise we would generate invalid code for non-monotonic expressions even when all involved dimensions are effectively untiled.
1 parent 54b4bfb commit 479c8d6
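As an illustration of the problem the commit message describes (this is not code from the patch), consider the map (d0) -> (d0 mod 3). The slice computation derives the accessed extent from the map's value at the bounds of the iteration domain, which is only valid for monotonically increasing expressions. The standalone C++ sketch below, using an arbitrarily chosen domain of [0, 8), shows how that endpoint-based reasoning underestimates the accessed range for a mod expression:

// Illustrative only: why a non-monotonic affine expression breaks
// endpoint-based slice computation. For (d0) -> (d0 mod 3) over [0, 8),
// the maximum value (2) is not attained at the upper bound, so deriving
// the accessed extent from expr(lb) and expr(ub - 1) is wrong.
#include <algorithm>
#include <cstdio>

int main() {
  auto expr = [](int d0) { return d0 % 3; }; // stand-in for (d0) -> (d0 mod 3)
  const int lb = 0, ub = 8;

  int trueMax = 0;
  for (int d0 = lb; d0 < ub; ++d0)
    trueMax = std::max(trueMax, expr(d0));

  // Endpoint-based reasoning, valid only for monotonically increasing maps.
  int endpointMax = std::max(expr(lb), expr(ub - 1));

  std::printf("true max accessed index:    %d\n", trueMax);     // 2
  std::printf("endpoint-derived max index: %d\n", endpointMax); // 1
  return 0;
}

Returning the full slice whenever a dimension is effectively untiled (tile size 0 or equal to the trip count) sidesteps this computation entirely.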

8 files changed: 46 additions and 77 deletions

mlir/include/mlir/IR/AffineExpr.h

Lines changed: 0 additions & 5 deletions
@@ -110,11 +110,6 @@ class AffineExpr {
   /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
   bool isPureAffine() const;
 
-  /// Returns true if this expression is monotonicically increasing with respect
-  /// to the AffineDimExprs, i.e. increasing the value of any AffineDimExpr will
-  /// never decrease the value of the result.
-  bool isMonotonicallyIncreasing() const;
-
   /// Returns the greatest known integral divisor of this affine expression. The
   /// result is always positive.
   int64_t getLargestKnownDivisor() const;

mlir/include/mlir/IR/AffineMap.h

Lines changed: 0 additions & 4 deletions
@@ -382,10 +382,6 @@ class AffineMap {
   /// Returns true if the AffineMap represents a symbol-less permutation map.
   bool isPermutation() const;
 
-  // Returns true if every result is monotonically increasing.
-  // See AffineExpr::isMonotonicallyIncreasing().
-  bool isComponentWiseMonotonicallyIncreasing() const;
-
   /// Returns the map consisting of the `resultPos` subset.
   AffineMap getSubMap(ArrayRef<unsigned> resultPos) const;

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 6 additions & 3 deletions
@@ -115,13 +115,16 @@ struct LinalgOpTilingInterface
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
                          ArrayRef<OpFoldResult> sizes) const {
-    // Leave the `sizeBounds` value empty. That is only needed when the `sizes`
-    // specified could lead to out of bounds accesses.
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
+    SmallVector<OpFoldResult> allShapeSizes =
+        linalgOp.createFlatListOfOperandDims(b, linalgOp.getLoc());
+    SmallVector<OpFoldResult> sizeBounds =
+        mlir::affine::makeComposedFoldedMultiResultAffineApply(
+            b, loc, linalgOp.getShapesToLoopsMap(), allShapeSizes);
     SmallVector<Value> valuesToTile = linalgOp->getOperands();
     SmallVector<Value> tiledOperands = makeTiledShapes(
-        b, loc, linalgOp, valuesToTile, offsets, sizes, {}, true);
+        b, loc, linalgOp, valuesToTile, offsets, sizes, sizeBounds, true);
     SmallVector<Operation *> generatedSlices = llvm::map_to_vector(
         llvm::make_filter_range(
             tiledOperands,

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 25 additions & 10 deletions
@@ -56,10 +56,23 @@ namespace {
 // `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
 //
 struct TileCheck : public AffineExprVisitor<TileCheck> {
-  TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
+  TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> sizeBounds)
+      : tileSizes(tileSizes), sizeBounds(sizeBounds) {}
 
   void visitDimExpr(AffineDimExpr expr) {
-    isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
+    unsigned pos = expr.getPosition();
+
+    // This dimension is tiled if the tile size is larger than zero and not
+    // equal to its domain size (if statically known).
+    std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
+    if (tileSize && !sizeBounds.empty()) {
+      std::optional<int64_t> sizeBound = getConstantIntValue(sizeBounds[pos]);
+      if (sizeBound && *sizeBound == *tileSize) {
+        return;
+      }
+    }
+
+    isTiled |= !isZeroIndex(tileSizes[pos]);
   }
   void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
     visit(expr.getLHS());
@@ -70,24 +83,27 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
   }
   bool isTiled = false;
   ArrayRef<OpFoldResult> tileSizes;
+  ArrayRef<OpFoldResult> sizeBounds;
 };
 
 } // namespace
 
-static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
+static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
+                    ArrayRef<OpFoldResult> sizeBounds) {
   if (!expr)
     return false;
-  TileCheck t(tileSizes);
+  TileCheck t(tileSizes, sizeBounds);
   t.visit(expr);
   return t.isTiled;
 }
 
 // Checks whether the `map varies with respect to a non-zero `tileSize`.
-static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
+static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
+                    ArrayRef<OpFoldResult> sizeBounds) {
   if (!map)
     return false;
   for (unsigned r = 0; r < map.getNumResults(); ++r)
-    if (isTiled(map.getResult(r), tileSizes))
+    if (isTiled(map.getResult(r), tileSizes, sizeBounds))
       return true;
   return false;
 }
@@ -585,7 +601,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
     // The offset & size computation below only handles the case when
     // the map is monotonically increasing, i.e. the min and max values are
    // attained at the lower and upper bounds of the iteration domain.
-    if (!isTiled(m, tileSizes) || !m.isComponentWiseMonotonicallyIncreasing()) {
+    if (!isTiled(m, tileSizes, ubs)) {
      sliceParams.offsets.push_back(builder.getIndexAttr(0));
      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
      sliceParams.sizes.push_back(dim);
@@ -784,10 +800,9 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
    // transformations such as padding and bufferization since the
    // extract/insert slice pairs make the accessed iteration argument
    // subdomains explicit.
-
    Type operandType = opOperand.get().getType();
-    if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
-                                      linalgOp.isDpsInit(&opOperand))) {
+    if (!isTiled(map, tileSizes, {}) && !(isa<RankedTensorType>(operandType) &&
+                                          linalgOp.isDpsInit(&opOperand))) {
      allSliceParams.push_back(std::nullopt);
      LLVM_DEBUG(llvm::dbgs()
                 << ": not tiled: use shape: " << operandType << "\n");

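Distilling the Utils.cpp change above: a dimension is now considered untiled not only when its tile size is 0, but also when the statically known tile size equals the statically known loop bound, in which case the producer's full slice is used. A minimal standalone sketch of that decision (hypothetical helper name, not the upstream TileCheck code):

// Sketch of the rule added to TileCheck: a dimension counts as untiled when
// its tile size is 0, or when the static tile size equals the static trip count.
#include <cassert>
#include <cstdint>
#include <optional>

bool dimIsTiled(std::optional<int64_t> tileSize,
                std::optional<int64_t> loopTripCount) {
  if (!tileSize)
    return true; // dynamic tile size: conservatively assume tiled
  if (*tileSize == 0)
    return false; // explicit "untiled" marker
  if (loopTripCount && *loopTripCount == *tileSize)
    return false; // tiling by the full trip count leaves the dim untiled
  return true;
}

int main() {
  assert(dimIsTiled(2, 7));            // genuinely tiled
  assert(!dimIsTiled(0, 7));           // tile size 0 -> untiled
  assert(!dimIsTiled(7, 7));           // full trip count -> untiled (new rule)
  assert(dimIsTiled(std::nullopt, 7)); // dynamic tile size -> assume tiled
  return 0;
}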
mlir/lib/IR/AffineExpr.cpp

Lines changed: 0 additions & 36 deletions
@@ -239,42 +239,6 @@ bool AffineExpr::isPureAffine() const {
   llvm_unreachable("Unknown AffineExpr");
 }
 
-static bool isNonNegativeConstant(AffineExpr expr) {
-  auto constant = dyn_cast<AffineConstantExpr>(expr);
-  return constant && constant.getValue() >= 0;
-}
-
-bool AffineExpr::isMonotonicallyIncreasing() const {
-  switch (getKind()) {
-  case AffineExprKind::SymbolId:
-  case AffineExprKind::DimId:
-  case AffineExprKind::Constant:
-    return true;
-  case AffineExprKind::Add: {
-    auto op = llvm::cast<AffineBinaryOpExpr>(*this);
-    return op.getLHS().isMonotonicallyIncreasing() &&
-           op.getRHS().isMonotonicallyIncreasing();
-  }
-  case AffineExprKind::Mul: {
-    // One operand must be a non-negative constant.
-    auto op = llvm::cast<AffineBinaryOpExpr>(*this);
-    return op.getLHS().isMonotonicallyIncreasing() &&
-           op.getRHS().isMonotonicallyIncreasing() &&
-           (isNonNegativeConstant(op.getLHS()) ||
-            isNonNegativeConstant(op.getRHS()));
-  }
-  case AffineExprKind::FloorDiv:
-  case AffineExprKind::CeilDiv: {
-    auto op = llvm::cast<AffineBinaryOpExpr>(*this);
-    return op.getLHS().isMonotonicallyIncreasing() &&
-           isNonNegativeConstant(op.getRHS());
-  }
-  case AffineExprKind::Mod:
-    return false;
-  }
-  llvm_unreachable("Unknown AffineExpr");
-}
-
 // Returns the greatest known integral divisor of this affine expression.
 int64_t AffineExpr::getLargestKnownDivisor() const {
   AffineBinaryOpExpr binExpr(nullptr);

mlir/lib/IR/AffineMap.cpp

Lines changed: 0 additions & 5 deletions
@@ -651,11 +651,6 @@ bool AffineMap::isPermutation() const {
   return isProjectedPermutation();
 }
 
-bool AffineMap::isComponentWiseMonotonicallyIncreasing() const {
-  return all_of(getResults(),
-                [](auto expr) { return expr.isMonotonicallyIncreasing(); });
-}
-
 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
   SmallVector<AffineExpr, 4> exprs;
   exprs.reserve(resultPos.size());

mlir/test/Dialect/Linalg/tile-tensors.mlir

Lines changed: 13 additions & 13 deletions
@@ -171,31 +171,31 @@ module attributes {transform.with_named_sequence} {
 // -----
 
 // CHECK-LABEL: func @non_monotonic_affine_expr
-//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?xf32>
-func.func @non_monotonic_affine_expr(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+//  CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<7xf32>
+func.func @non_monotonic_affine_expr(%arg0 : tensor<7xf32>) -> tensor<7xf32> {
   %c0 = arith.constant 0 : index
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
-  %empty = tensor.empty(%0) : tensor<?xf32>
+  %0 = tensor.dim %arg0, %c0 : tensor<7xf32>
+  %empty = tensor.empty() : tensor<7xf32>
 
-  // CHECK: scf.for
-  // CHECK:   %[[SIZE:[a-zA-Z0-9_]+]] = tensor.dim %[[ARG0]],
-  // CHECK:   tensor.extract_slice %[[ARG0]][0] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
+  // CHECK: %[[OUT:.*]] = tensor.empty() : tensor<7xf32>
+  // CHECK: scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[OUT]]) -> (tensor<7xf32>) {
+  // CHECK:   tensor.extract_slice %[[TC0]][0] [7] [1] : tensor<7xf32> to tensor<7xf32>
   %generic = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> (d0 mod 3)>,
+    {indexing_maps = [affine_map<(d0) -> (d0 mod 4)>,
                       affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
-    ins(%arg0: tensor<?xf32>)
-    outs(%empty : tensor<?xf32>) {
+    ins(%arg0: tensor<7xf32>)
+    outs(%empty : tensor<7xf32>) {
       ^bb0(%in : f32, %out: f32):
         linalg.yield %in : f32
-  } -> tensor<?xf32>
-  return %generic : tensor<?xf32>
+  } -> tensor<7xf32>
+  return %generic : tensor<7xf32>
 }
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [100] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    %1, %loop = transform.structured.tile_using_for %0 tile_sizes [7] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
     transform.yield
   }
 }

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 2 additions & 1 deletion
@@ -555,11 +555,12 @@ module {
 
         // CHECK: %[[T1:.*]] = linalg.generic {{.*}}
         // CHECK: %[[T2:.*]] = linalg.generic {{.*}}
+        // CHECK: %[[T3:.*]] = linalg.generic {{.*}}
         %7 = tensor.extract_slice %1[%4] [%5] [1] : tensor<?xf32> to tensor<?xf32>
 
         %8 = linalg.elemwise_unary ins(%7 : tensor<?xf32>) outs(%6 : tensor<?xf32>) -> tensor<?xf32>
         scf.forall.in_parallel {
-          // CHECK: tensor.parallel_insert_slice %[[T2]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
+          // CHECK: tensor.parallel_insert_slice %[[T3]] into %[[ARG7]][%[[I0]]] [%[[I1]]] [1] : tensor<?xf32> into tensor<?xf32>
          tensor.parallel_insert_slice %8 into %o[%2] [%5] [1] : tensor<?xf32> into tensor<?xf32>
        }
      }
