Skip to content

Commit 2d9bc62

Browse files
committed
fixup! fixup! fixup! [mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern
Minor tweak
1 parent 5efa829 commit 2d9bc62

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2828
#include "mlir/IR/AffineExpr.h"
2929
#include "mlir/IR/Matchers.h"
30-
#include "mlir/IR/PatternMatch.h"
3130
#include "mlir/Pass/Pass.h"
3231
#include "mlir/Support/LLVM.h"
3332
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1194,27 +1193,27 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11941193
// 2. Transpose the tile to match the inner tile order:
11951194
// %init = tensor.empty()
11961195
// %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
1197-
// NOTE: Outer dims are 1 and hence effectively ignored.
1196+
// NOTE: Outer dims are 1 and hence effectively ignored.
11981197
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
11991198
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
12001199

12011200
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
12021201
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
12031202

12041203
// 2.1 Create tensor.empty (init value for TransposeOp)
1205-
SmallVector<OpFoldResult> transShapeForEmptyOpDynamic;
1204+
SmallVector<OpFoldResult> transShapeForEmptyOp;
12061205

12071206
// Acquire tensor shape required to create EmptyOp. This will match the inner
12081207
// tile sizes.
12091208
size_t idx = numTiles;
12101209
while (idx != 0) {
1211-
transShapeForEmptyOpDynamic.push_back(extractSliceSizes[srcRank - idx]);
1210+
transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
12121211
idx--;
12131212
}
12141213

1215-
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOpDynamic, perm);
1216-
Value empty = rewriter.create<tensor::EmptyOp>(
1217-
loc, transShapeForEmptyOpDynamic, elemType);
1214+
applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
1215+
Value empty =
1216+
rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
12181217

12191218
// 2.2 Create linalg.transpose
12201219
auto transposedOp =

0 commit comments

Comments
 (0)