Skip to content

Commit a067405

Browse files
committed
Tiling for full tile size
1 parent 8bf67e1 commit a067405

File tree

2 files changed

+44
-27
lines changed

2 files changed

+44
-27
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1313
#include "mlir/Dialect/SCF/IR/SCF.h"
1414
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
15+
#include "llvm/ADT/ArrayRef.h"
1516
#include "llvm/ADT/StringSet.h"
1617
#include <optional>
1718

@@ -143,12 +144,11 @@ struct SliceParameters {
143144
///
144145
/// `omitPartialTileCheck` controls whether to omit the partial/boundary tile
145146
/// condition check in cases where we statically know that it is unnecessary.
146-
SliceParameters
147-
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
148-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
149-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
150-
ArrayRef<OpFoldResult> subShapeSizes,
151-
bool omitPartialTileCheck);
147+
SliceParameters computeSliceParameters(
148+
OpBuilder &builder, Location loc, Value valueToTile,
149+
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
150+
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
151+
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes = {});
152152

153153
/// Computes SliceParamaters for all `valuesToTile` of the given `linalgOp`,
154154
/// assuming `linalgOp` is being fused into a loop nest. Calls
@@ -177,7 +177,8 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
177177
ArrayRef<OpFoldResult> lbs,
178178
ArrayRef<OpFoldResult> ubs,
179179
ArrayRef<OpFoldResult> subShapeSizes,
180-
bool omitPartialTileCheck);
180+
bool omitPartialTileCheck,
181+
ArrayRef<OpFoldResult> sizeBounds = {});
181182

182183
/// Creates extract_slice/subview ops for all `valuesToTile` of the given
183184
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,21 @@ namespace {
5656
// `d0 + 2 * d1 + d3` is tiled by [0, 0, 0, 2] but not by [0, 0, 2, 0]
5757
//
5858
struct TileCheck : public AffineExprVisitor<TileCheck> {
59-
TileCheck(ArrayRef<OpFoldResult> tileSizes) : tileSizes(tileSizes) {}
59+
TileCheck(ArrayRef<OpFoldResult> tileSizes, ArrayRef<int64_t> domainSizes)
60+
: tileSizes(tileSizes), domainSizes(domainSizes) {}
6061

6162
void visitDimExpr(AffineDimExpr expr) {
62-
isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]);
63+
unsigned pos = expr.getPosition();
64+
65+
// There is no tile if all tile sizes correspond to the domain size
66+
std::optional<int64_t> tileSize = getConstantIntValue(tileSizes[pos]);
67+
if (tileSize && !domainSizes.empty()) {
68+
if (domainSizes[pos] == *tileSize) {
69+
return;
70+
}
71+
}
72+
73+
isTiled |= !isZeroIndex(tileSizes[pos]);
6374
}
6475
void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) {
6576
visit(expr.getLHS());
@@ -70,24 +81,28 @@ struct TileCheck : public AffineExprVisitor<TileCheck> {
7081
}
7182
bool isTiled = false;
7283
ArrayRef<OpFoldResult> tileSizes;
84+
ArrayRef<int64_t> domainSizes;
7385
};
7486

7587
} // namespace
7688

77-
static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes) {
89+
static bool isTiled(AffineExpr expr, ArrayRef<OpFoldResult> tileSizes,
90+
ArrayRef<int64_t> domainSizes) {
7891
if (!expr)
7992
return false;
80-
TileCheck t(tileSizes);
93+
94+
TileCheck t(tileSizes, domainSizes);
8195
t.visit(expr);
8296
return t.isTiled;
8397
}
8498

8599
// Checks whether the `map varies with respect to a non-zero `tileSize`.
86-
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes) {
100+
static bool isTiled(AffineMap map, ArrayRef<OpFoldResult> tileSizes,
101+
ArrayRef<int64_t> domainSizes) {
87102
if (!map)
88103
return false;
89104
for (unsigned r = 0; r < map.getNumResults(); ++r)
90-
if (isTiled(map.getResult(r), tileSizes))
105+
if (isTiled(map.getResult(r), tileSizes, domainSizes))
91106
return true;
92107
return false;
93108
}
@@ -556,19 +571,19 @@ Operation *makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
556571
ArrayRef<OpFoldResult> lbs,
557572
ArrayRef<OpFoldResult> ubs,
558573
ArrayRef<OpFoldResult> subShapeSizes,
559-
bool omitPartialTileCheck) {
560-
SliceParameters sliceParams =
561-
computeSliceParameters(builder, loc, valueToTile, tileSizes, map, lbs,
562-
ubs, subShapeSizes, omitPartialTileCheck);
574+
bool omitPartialTileCheck,
575+
ArrayRef<int64_t> domainSizes) {
576+
SliceParameters sliceParams = computeSliceParameters(
577+
builder, loc, valueToTile, tileSizes, map, lbs, ubs, subShapeSizes,
578+
omitPartialTileCheck, domainSizes);
563579
return materializeTiledShape(builder, loc, valueToTile, sliceParams);
564580
}
565581

566-
SliceParameters
567-
computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
568-
ArrayRef<OpFoldResult> tileSizes, AffineMap map,
569-
ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
570-
ArrayRef<OpFoldResult> subShapeSizes,
571-
bool omitPartialTileCheck) {
582+
SliceParameters computeSliceParameters(
583+
OpBuilder &builder, Location loc, Value valueToTile,
584+
ArrayRef<OpFoldResult> tileSizes, AffineMap map, ArrayRef<OpFoldResult> lbs,
585+
ArrayRef<OpFoldResult> ubs, ArrayRef<OpFoldResult> subShapeSizes,
586+
bool omitPartialTileCheck, ArrayRef<int64_t> domainSizes) {
572587
auto shapedType = dyn_cast<ShapedType>(valueToTile.getType());
573588
assert(shapedType && "only shaped types can be tiled");
574589
ArrayRef<int64_t> shape = shapedType.getShape();
@@ -585,7 +600,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
585600
// The offset & size computation below only handles the case when
586601
// the map is monotonically increasing, i.e. the min and max values are
587602
// attained at the lower and upper bounds of the iteration domain.
588-
if (!isTiled(m, tileSizes) || !m.isComponentWiseMonotonicallyIncreasing()) {
603+
if (!isTiled(m, tileSizes, domainSizes)) {
589604
sliceParams.offsets.push_back(builder.getIndexAttr(0));
590605
OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
591606
sliceParams.sizes.push_back(dim);
@@ -786,8 +801,9 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
786801
// subdomains explicit.
787802

788803
Type operandType = opOperand.get().getType();
789-
if (!isTiled(map, tileSizes) && !(isa<RankedTensorType>(operandType) &&
790-
linalgOp.isDpsInit(&opOperand))) {
804+
if (!isTiled(map, tileSizes, linalgOp.getStaticLoopRanges()) &&
805+
!(isa<RankedTensorType>(operandType) &&
806+
linalgOp.isDpsInit(&opOperand))) {
791807
allSliceParams.push_back(std::nullopt);
792808
LLVM_DEBUG(llvm::dbgs()
793809
<< ": not tiled: use shape: " << operandType << "\n");
@@ -797,7 +813,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp,
797813

798814
allSliceParams.push_back(computeSliceParameters(
799815
builder, loc, shapedOp, tileSizes, map, lbs, sizeBounds, subShapeSizes,
800-
omitPartialTileCheck));
816+
omitPartialTileCheck, linalgOp.getStaticLoopRanges()));
801817
}
802818

803819
return allSliceParams;

0 commit comments

Comments
 (0)