Skip to content

Commit 744a291

Browse files
committed
Fixed all issues pointed out by HanHan except factoring in StripMineTensorType
1 parent c33642b commit 744a291

File tree

5 files changed

+322
-339
lines changed

5 files changed

+322
-339
lines changed

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ computeTransposedType(RankedTensorType rankedTensorType,
3838
/// i.e. for a pack from an ABCD layout to an ABCDba:
3939
/// The packed shape would be ABCDba.
4040
/// The pre-permutation shape would be AaBbCD.
41-
SmallVector<int64_t> getPackInverseDestPermutation(PackOp packOp);
41+
SmallVector<int64_t> getPackUnPackInverseDestPerm(
42+
std::variant<tensor::PackOp, tensor::UnPackOp> packOp);
4243

4344
/// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
4445
/// source tensor or inserts the source tensor into a destination tensor with

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
237237
PackingMetadata packingMetadata = computePackingMetadata(
238238
packedTensorType.getRank(), packOp.getInnerDimsPos());
239239
SmallVector<int64_t> packedToStripMinedShapePerm =
240-
tensor::getPackInverseDestPermutation(packOp);
240+
tensor::getPackUnPackInverseDestPerm(packOp);
241241

242242
// 3. Compute the stripMinedShape: this is the packed shape before any outer
243243
// or inner permutations have been applied.

0 commit comments

Comments
 (0)