Skip to content

Commit 128c0e5

Browse files
committed
fixup! [mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern
SKip calculating static shapes for EmptyOp
1 parent 49cf5cc commit 128c0e5

File tree

1 file changed

+5
-15
lines changed

1 file changed

+5
-15
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,13 +1147,13 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11471147
// * the return size becomes the attribute encapsulating the known size, and
11481148
// * dim is updated from kDynamic to its actual known value.
11491149
static std::pair<int64_t, OpFoldResult>
1150-
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, PatternRewriter &rewriter) {
1150+
getSimplifiedDimSizePair(OpFoldResult tileSizeOfr, Builder &b) {
11511151
int64_t tileSizeForShape =
11521152
getConstantIntValue(tileSizeOfr).value_or(ShapedType::kDynamic);
11531153

11541154
OpFoldResult tileSizeOfrSimplified;
11551155
if (tileSizeForShape != ShapedType::kDynamic) {
1156-
tileSizeOfrSimplified = rewriter.getIndexAttr(tileSizeForShape);
1156+
tileSizeOfrSimplified = b.getIndexAttr(tileSizeForShape);
11571157
} else {
11581158
tileSizeOfrSimplified = tileSizeOfr;
11591159
}
@@ -1226,28 +1226,18 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
12261226

12271227
// 2.1 Create tensor.empty (init value for TransposeOp)
12281228
SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
1229-
SmallVector<int64_t> transShapeForEmptyOpStatic;
12301229

12311230
// Acquire tensor shape required to create EmptyOp. This will match the inner
1232-
// tile sizes, but the actual data format will depend on whether the tile
1233-
// sizes are static or dynamic (each case leads to a different builder for
1234-
// EmptyOp). Conservatively, prepare for both scenarios.
1231+
// tile sizes.
12351232
size_t idx = numTiles;
12361233
while (idx != 0) {
12371234
transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
1238-
transShapeForEmptyOpStatic.push_back(
1239-
outputShapeForExtractSlice[numTiles - idx]);
12401235
idx--;
12411236
}
12421237

1243-
applyPermutationToVector<int64_t>(transShapeForEmptyOpStatic, perm);
12441238
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
1245-
1246-
Value empty = ShapedType::isDynamicShape(transShapeForEmptyOpStatic)
1247-
? rewriter.create<tensor::EmptyOp>(
1248-
loc, transShapeForEmptyOpDynamic, elemType)
1249-
: rewriter.create<tensor::EmptyOp>(
1250-
loc, transShapeForEmptyOpStatic, elemType);
1239+
Value empty = rewriter.create<tensor::EmptyOp>(
1240+
loc, transShapeForEmptyOpDynamic, elemType);
12511241

12521242
// 2.2 Create linalg.transpose
12531243
auto transposedOp =

0 commit comments

Comments
 (0)