Commit 9f6a1dd

[mlir][tensor] Introduce FoldTensorCastUnPackOp (#121393)
This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This mirrors a similar update made for `tensor::PackOp` in #114559. Below is the updated rationale tailored to `tensor::UnPackOp`.

ISSUE DESCRIPTION

Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
// Note: `%c8` and `?`.
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
```

as:

```mlir
// Note: `%c8` and `8`.
%unpack = tensor.unpack %cast
  inner_dims_pos = [0, 1]
  inner_tiles = [%c8, 1]
  into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```

This triggers an Op verification failure because the folder does not update the inner tile sizes in the unpack Op. This patch addresses the issue by ensuring proper handling of inner tile sizes.

ADDITIONAL CHANGES

* invalid.mlir: Fixed a typo.
* TensorOps.cpp:
  * Removed unnecessary `(void)tileSize`.
  * Added comments following the discussion in PR #115772.
  * Made minor updates to `FoldTensorCastPackOp` for consistency with the newly introduced `FoldTensorCastUnPackOp`.
* Tensor/canonicalize.mlir: Ensured consistent usage of `test_attr` (e.g., replaced mixed use of `test_attr` and `some_attr`).
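With the dedicated pattern in place, the fold keeps the inner tile sizes in sync with the packed type. The expected output below is a sketch taken from the new `@fold_cast_unpack_dynamic_tile_size` test added in this commit:

```mlir
// Note: both the packed type and the inner tile size are now static (`8`).
%unpack = tensor.unpack %src
  inner_dims_pos = [0, 1]
  inner_tiles = [8, 1]
  into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```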
1 parent 8b23ebb commit 9f6a1dd

File tree

3 files changed: +123 −27 lines changed


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 99 additions & 24 deletions
```diff
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
   return newOperands;
 }
 
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+//  * a dim from newPackedTy is static, and
+//  * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+                     SmallVector<OpFoldResult> mixedTiles) {
+  SmallVector<OpFoldResult> newMixedTileSizes;
+  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+                               .getShape()
+                               .take_back(mixedTiles.size()),
+                           mixedTiles)) {
+    int64_t shape = std::get<0>(it);
+    if (shape == ShapedType::kDynamic) {
+      newMixedTileSizes.push_back(std::get<1>(it));
+      continue;
+    }
+
+    // If the current result dim is static, update the dynamic mixed-size
+    // (provided the original value is dynamic).
+    OpFoldResult tile = std::get<1>(it);
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+      // Already a constant
+      newMixedTileSizes.push_back(tile);
+    } else {
+      assert(getConstantIntValue(tile).value() == shape &&
+             "tile size and dim size don't match!");
+      newMixedTileSizes.push_back(
+          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+    }
+  }
+
+  return newMixedTileSizes;
+}
+
 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
 /// `tensor.cast` has source that is more static than the consuming op.
 ///
```
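To make the new helper concrete, here is a sketch of its effect at the IR level, using the shapes from the `@fold_cast_pack_dynamic_tile_size` test updated in this commit: once the cast is folded and the packed dim becomes static, the dynamic tile size `%c8` is rewritten to the matching constant `8`, while the already-static tile size `1` is preserved.

```mlir
// Before folding: dynamic packed dim, dynamic tile size %c8.
%pack = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1]
    inner_tiles = [%c8, 1]
    into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
```

folds to:

```mlir
// After folding: static packed dim, constant tile size 8.
%pack = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1]
    inner_tiles = [8, 1]
    into %dest : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```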
```diff
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
-                                 .getShape()
-                                 .take_back(op.getMixedTiles().size()),
-                             op.getMixedTiles())) {
-      int64_t shape = std::get<0>(it);
-      if (shape == ShapedType::kDynamic) {
-        newMixedTileSizes.push_back(std::get<1>(it));
-        continue;
-      }
-
-      if (Attribute attr =
-              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-        // Already a constant
-        newMixedTileSizes.push_back(std::get<1>(it));
-      } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
-        newMixedTileSizes.push_back(
-            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-      }
-    }
+    SmallVector<OpFoldResult> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
 
     // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
```
```diff
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
 
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+        rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
+
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
```
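One detail worth spelling out: the `// Replace op.` step above only materializes a compensating `tensor.cast` when folding made the result type more static. A minimal sketch with hypothetical shapes (illustrative, not taken from the patch): if the unpack's dest arrives through a generalizing cast, the folded op produces the static type and a cast back keeps existing uses well-typed.

```mlir
// Before: the dest reaches the unpack through a generalizing cast.
%cast = tensor.cast %dest : tensor<7x1xi32> to tensor<7x?xi32>
%0 = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %cast : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```

becomes:

```mlir
// After: the unpack uses %dest directly; a cast restores the original
// result type for any existing uses.
%0 = tensor.unpack %src inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<1x1x8x1xi32> -> tensor<7x1xi32>
%1 = tensor.cast %0 : tensor<7x1xi32> to tensor<7x?xi32>
```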
```diff
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {
 
     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
```
```diff
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<FoldTensorCastPackOp>(getContext());
+  results.add<FoldTensorCastUnPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
```

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 23 additions & 2 deletions
```diff
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
 // -----
 
 // CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
```
```diff
@@ -2794,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
 // CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
 // CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
-// CHECK-SAME:        some_attr
+// CHECK-SAME:        test_attr
 // CHECK-SAME:        : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
 // CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
 func.func @fold_cast_pack_dynamic_tile_size(
```
```diff
@@ -2807,13 +2808,33 @@ func.func @fold_cast_pack_dynamic_tile_size(
   %pack = tensor.pack %src padding_value(%pad : i32)
     inner_dims_pos = [0, 1]
     inner_tiles = [%c8, 1]
-    into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+    into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
   %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
   return %res : tensor<1x1x8x1xi32>
 }
 
 // -----
 
+// CHECK-LABEL:   func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK:           return %[[RES]] : tensor<7x?xi32>
+func.func @fold_cast_unpack_dynamic_tile_size(
+    %src: tensor<1x1x8x1xi32>,
+    %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+  %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+  %c8 = arith.constant 8 : index
+  %unpack = tensor.unpack %cast
+    inner_dims_pos = [0, 1]
+    inner_tiles = [%c8, 1]
+    into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+  return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_dont_drop_attributes(
 // CHECK:         tensor.pack {{.*}} {test_attr}
 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
```
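As a usage note, tests in canonicalize.mlir are exercised through lit. A minimal sketch of the invocation (the actual RUN line at the top of the file may carry extra flags):

```mlir
// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
```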

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 1 addition & 1 deletion
```diff
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
 
 // -----
 
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
   // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
   %0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>
```
