Commit 9b9369e
[mlir][tensor] Improve FoldTensorCastProducerOp (dynamic shapes) (#114559)
Currently, `FoldTensorCastProducerOp` incorrectly folds the following:

```mlir
%pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
%res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
```

as (note the static trailing dim in the result and the still-dynamic tile size that corresponds to it):

```mlir
%res = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```

This triggers an op verification failure because the folder does not update the inner tile sizes in the pack op. This PR addresses that.

Note: supporting other ops with size-like attributes is left as a TODO.
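With the fix, the folder also rewrites the corresponding tile size to the matching static constant, so the fold instead produces the following (this mirrors the CHECK line of the new regression test below):

```mlir
%res = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %dest : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```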
1 parent a993dfc commit 9b9369e

File tree

2 files changed: +133 −33 lines


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 112 additions & 31 deletions
```diff
@@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
+  // 2. Exclude DPS ops that are also LoopLike from this interface as they
+  // might need special handling of attached regions.
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+
+  return hasTensorCastOperand;
+}
+
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
+        assert(tileSize == shape && "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    PackOp newOp = rewriter.create<PackOp>(
+        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
 
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
@@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp
 
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    // InsertSliceOp has its own logic about folding tensor.cast ops.
-    if (isa<InsertSliceOp>(op.getOperation()))
-      return failure();
 
-    // Exclude DPS ops that are also LoopLike from this interface as they
-    // might need special handling of attached regions.
-    if (isa<LoopLikeOpInterface>(op.getOperation()))
+    // Reject tensor::PackOp - there's dedicated pattern for that instead.
+    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();
 
-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (llvm::isa<BlockArgument>(opOperand.get()))
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
-      return failure();
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
-    SmallVector<Type, 4> newResultTypes(op->getResultTypes());
-    SmallVector<Value, 4> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Assumes that the result has dpsInits followed by nonDpsInits.
-    int64_t dpsInitIdx = 0;
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-      if (op.isDpsInit(&opOperand) &&
-          !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
-    }
+    // Clone op
+    auto newOp = clone(rewriter, op, newResultTypes, newOperands);
 
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto [oldResult, newResult] :
@@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp
 
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
+  results.add<FoldTensorCastPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
 
```
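For context on the tile-size update in `FoldTensorCastPackOp`: for `tensor.pack`, each `inner_tiles` entry materializes as the corresponding trailing dimension of the packed result, which is why a result dimension made static by folding the cast must be accompanied by a matching constant tile size. A minimal sketch of that invariant (hypothetical shapes, not taken from this patch):

```mlir
// inner_tiles = [8, 1] appear verbatim as the trailing result dims 8x1;
// the outer dims are the tiled sizes 16/8 = 2 and 4/1 = 4.
%packed = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dst : tensor<16x4xi32> -> tensor<2x4x8x1xi32>
```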
mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 21 additions & 2 deletions
```diff
@@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> {
 
 // -----
 
-// CHECK-LABEL: func.func @test_destination_multiple_result(
+// CHECK-LABEL: func.func @fold_cast_multiple_results(
 // CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>,
 // CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index {
 // CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>)
 // CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index
 // CHECK: return %[[RES]]#1 : index
-func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
+func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index {
   %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor<?x2xf32>
   %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor<?x2xf32>
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+// -----
+
+// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
+// CHECK-SAME:    %[[DEST:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:    %[[SRC:.*]]: tensor<7x?xi32>,
+// CHECK-SAME:    %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
+// CHECK:         %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
+// CHECK:         return %[[PACK]] : tensor<1x1x8x1xi32>
+func.func @fold_cast_pack_dynamic_tile_size(
+    %dest: tensor<1x1x8x1xi32>,
+    %src: tensor<7x?xi32>,
+    %pad: i32) -> tensor<1x1x8x1xi32> {
+
+  %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+  %c8 = arith.constant 8 : index
+  %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+  %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
+  return %res : tensor<1x1x8x1xi32>
+}
 
 // -----
 
```
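Both patterns are registered via `TensorDialect::getCanonicalizationPatterns`, so the new test is exercised by the canonicalizer; running `mlir-opt` with `-split-input-file -canonicalize` over `canonicalize.mlir` and piping the output through `FileCheck` (the file's existing RUN line is authoritative for the exact flags) reproduces the checks above.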