Skip to content

Commit 7916a3f

Browse files
committed
fixup! fixup! [mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern
Raname and move getSimplifiedDimSizePair
1 parent 128c0e5 commit 7916a3f

File tree

3 files changed

+37
-27
lines changed

3 files changed

+37
-27
lines changed

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
6060
SmallVectorImpl<Value> &dynamicVec,
6161
SmallVectorImpl<int64_t> &staticVec);
6262

63+
/// Given OpFoldResult representing dim size value (*), generates a pair of
64+
/// sizes:
65+
/// * 1st result, static value, contains an int64_t dim size that can be used
66+
/// to build ShapedType (ShapedType::kDynamic is used for truly dynamic dims),
67+
/// * 2nd result, dynamic value, contains OpFoldResult encapsulating the
68+
/// actual dim size (either original or updated input value).
69+
/// For input sizes for which it is possible to extract a constant Attribute,
70+
/// replaces the original size value with an integer attribute (unless it's
71+
/// already a constant Attribute). The 1st return value also becomes the actual
72+
/// integer size (as opposed ShapedType::kDynamic).
73+
///
74+
/// (*) This hook is usually used when, given input sizes as OpFoldResult,
75+
/// it's required to generate two vectors:
76+
/// * sizes as int64_t to generate a shape,
77+
/// * sizes as OpFoldResult for sizes-like attribute.
78+
/// Please update this comment if you identify other use cases.
79+
std::pair<int64_t, OpFoldResult>
80+
getSimplifiedOfrAndStaticSizePair(OpFoldResult ofr, Builder &b);
81+
6382
/// Extract integer values from the assumed ArrayAttr of IntegerAttr.
6483
template <typename IntTy>
6584
SmallVector<IntTy> extractFromIntegerArrayAttr(Attribute attr) {

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,37 +1139,14 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11391139
return perm;
11401140
}
11411141

1142-
// A helper function to generate a dim-and-size pair for Ops like
1143-
// ExtractSliceOp that require both:
1144-
// * dims to specify the output shape, and
1145-
// * sizes for the sizes attribute (or similar).
1146-
// For dynamic sizes, if the corresponding size is a compile time constant:
1147-
// * the return size becomes the attribute encapsulating the known size, and
1148-
// * dim is updated from kDynamic to its actual known value.
1149-
static std::pair<int64_t, OpFoldResult>
1150-
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, Builder &b) {
1151-
int64_t tileSizeForShape =
1152-
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
1153-
1154-
OpFoldResult tileSizeOfrSimplified;
1155-
if (tileSizeForShape != ShapedType::kDynamic) {
1156-
tileSizeOfrSimplified = b.getIndexAttr(tileSizeForShape);
1157-
} else {
1158-
tileSizeOfrSimplified = tileSizeOfr;
1159-
}
1160-
1161-
return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
1162-
tileSizeOfrSimplified);
1163-
}
1164-
11651142
LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11661143
tensor::PackOp packOp, PatternRewriter &rewriter) const {
11671144
// TODO: support the case that outer dimensions are not all 1s. A
11681145
// tensor.expand_shape will be generated in this case.
1169-
if (llvm::any_of(packOp.getTiledOuterDims(),
1146+
if (llvm::any_of(packOp.getAllOuterDims(),
11701147
[](int64_t dim) { return dim != 1; })) {
11711148
return rewriter.notifyMatchFailure(
1172-
packOp, "require the tiled outer dimensions of the result are all 1s");
1149+
packOp, "not all outer dimensions of the result are 1s");
11731150
}
11741151

11751152
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
@@ -1202,7 +1179,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12021179
for (auto i : llvm::seq<unsigned>(0, srcRank)) {
12031180
if (dimAndTileMapping.count(i)) {
12041181
auto [tileSize, tileSizeOfr] =
1205-
getSimplifiedDimSizePair(dimAndTileMapping[i], rewriter);
1182+
getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
12061183
extractSliceSizes.push_back(tileSizeOfr);
12071184
outputShapeForExtractSlice.push_back(tileSize);
12081185
}
@@ -1254,7 +1231,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12541231

12551232
for (auto tileSize : packOp.getMixedTiles()) {
12561233
auto [tileSizeStatic, tileSizeOfr] =
1257-
getSimplifiedDimSizePair(tileSize, rewriter);
1234+
getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
12581235
writeSizes.push_back(tileSizeOfr);
12591236
writeShape.push_back(tileSizeStatic);
12601237
}

mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ void dispatchIndexOpFoldResult(OpFoldResult ofr,
5858
staticVec.push_back(ShapedType::kDynamic);
5959
}
6060

61+
std::pair<int64_t, OpFoldResult>
62+
getSimplifiedOfrAndStaticSizePair(OpFoldResult tileSizeOfr, Builder &b) {
63+
int64_t tileSizeForShape =
64+
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
65+
66+
OpFoldResult tileSizeOfrSimplified =
67+
(tileSizeForShape != ShapedType::kDynamic)
68+
? b.getIndexAttr(tileSizeForShape)
69+
: tileSizeOfr;
70+
71+
return std::pair<int64_t, OpFoldResult>(tileSizeForShape,
72+
tileSizeOfrSimplified);
73+
}
74+
6175
void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
6276
SmallVectorImpl<Value> &dynamicVec,
6377
SmallVectorImpl<int64_t> &staticVec) {

0 commit comments

Comments
 (0)