
Commit 198cdaf

fixup! [mlir][tensor] Extend the logic to generalise tensor.pack
* Use `OpFoldResult` for the output shape of tensor::EmptyOp (to simplify the invocation)
* Rename `readShape` as `readShapeForExtractSlice`
* Rename `transpShape` as `transShapeForEmpty`
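To illustrate the first bullet, here is a minimal sketch (not taken from the patch; the helper name is made up) of why switching to `OpFoldResult` simplifies the `tensor::EmptyOp` invocation: the `ArrayRef<OpFoldResult>` builder separates static sizes (attributes) from dynamic sizes (SSA values) itself, so the caller no longer needs an explicit dynamic-shape branch.

```cpp
// Hypothetical helper (illustration only, not part of the patch).
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// With OpFoldResult, a single builder call covers both cases: IntegerAttr
// entries become part of the static result shape, while Value entries become
// dynamic-size operands of the created tensor.empty op.
static Value createEmptyFromMixedSizes(OpBuilder &builder, Location loc,
                                       ArrayRef<OpFoldResult> mixedSizes,
                                       Type elemType) {
  return builder.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
}
```

If this reading is right, the call site in the pattern reduces to a single `rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType)`, which is exactly the shape of the change in the diff below.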
1 parent: affd836

1 file changed: 13 additions, 12 deletions

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

@@ -1177,12 +1177,15 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
-  SmallVector<int64_t> readShape;
+  SmallVector<OpFoldResult> transShapeForEmpty;
+  SmallVector<int64_t> readShapeForExtractSlice;
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShape.push_back(getConstantIntValue(dimAndTileMapping[i])
-                              .value_or(ShapedType::kDynamic));
+      readShapeForExtractSlice.push_back(
+          getConstantIntValue(dimAndTileMapping[i])
+              .value_or(ShapedType::kDynamic));
       readSizes.push_back(dimAndTileMapping[i]);
+      transShapeForEmpty.push_back(dimAndTileMapping[i]);
       continue;
     }
     if (ShapedType::isDynamic(inputShape[i])) {
@@ -1191,12 +1194,14 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     } else {
       readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
     }
-    if (inputShape[i] != 1)
-      readShape.push_back(inputShape[i]);
+    if (inputShape[i] != 1) {
+      readShapeForExtractSlice.push_back(inputShape[i]);
+      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+    }
   }

   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShape, elemType);
+  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);

   Value tile = rewriter.create<tensor::ExtractSliceOp>(
       loc, readType, input, readOffsets, readSizes, readStrides);
@@ -1208,8 +1213,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););

-  SmallVector<int64_t> transpShape = readShape;
-  applyPermutationToVector<int64_t>(transpShape, perm);
+  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);

   // If there's a tile with a dynamic size, retrieve its size. ATM only 1
   // dynamic tile is allowed.
@@ -1222,10 +1226,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
   }

   Value empty =
-      ShapedType::isDynamicShape(cast<ShapedType>(input.getType()).getShape())
-          ? rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType,
-                                             dynDimSize)
-          : rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
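For context on how the renamed vectors fit together, here is a condensed, hypothetical sketch of the post-patch flow (the helper name, sizes, and permutation are made up for illustration): the mixed static/dynamic sizes are collected as `OpFoldResult`, permuted with `applyPermutationToVector`, and passed directly to the `tensor::EmptyOp` builder.

```cpp
// Condensed illustration of the pattern after this commit (assumed values;
// not the exact code from Transforms.cpp).
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value buildTransposedEmpty(OpBuilder &rewriter, Location loc,
                                  OpFoldResult tileSize, Type elemType,
                                  ArrayRef<int64_t> perm) {
  // One static outer dimension plus one (possibly dynamic) tile size, kept as
  // OpFoldResult so static and dynamic entries travel through the same vector.
  SmallVector<OpFoldResult> transShapeForEmpty = {rewriter.getIndexAttr(8),
                                                  tileSize};
  // Reorder the shape to match the transpose permutation, as the patch does.
  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
  // No isDynamicShape branch: the builder extracts dynamic sizes itself.
  return rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
}
```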