@@ -4698,6 +4698,111 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
 //===----------------------------------------------------------------------===//
 // Common Canonicalizers and Folders.
 //===----------------------------------------------------------------------===//
+bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
+  // 1. InsertSliceOp has its own logic about folding tensor.cast ops.
+  // 2. Exclude DPS ops that are also LoopLike from this interface as they
+  //    might need special handling of attached regions.
+  if (isa<InsertSliceOp>(op.getOperation()) ||
+      isa<LoopLikeOpInterface>(op.getOperation()))
+    return false;
+
+  // If no operand comes from a tensor::CastOp and can be folded then fail.
+  bool hasTensorCastOperand =
+      llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
+        if (llvm::isa<BlockArgument>(opOperand.get()))
+          return false;
+        auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+        return castOp && canFoldIntoConsumerOp(castOp);
+      });
+
+  return hasTensorCastOperand;
+}
+
+static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
+                                         SmallVector<Type> &newResTy) {
+  SmallVector<Value> newOperands;
+  newOperands.reserve(op->getNumOperands());
+
+  // Assumes that the result has dpsInits followed by nonDpsInits.
+  int64_t dpsInitIdx = 0;
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
+    bool fold = canFoldIntoConsumerOp(tensorCastOp);
+    newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
+    if (op.isDpsInit(&opOperand) &&
+        !llvm::isa<MemRefType>(newOperands.back().getType()))
+      newResTy[dpsInitIdx++] = newOperands.back().getType();
+  }
+  return newOperands;
+}
+
+/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
+/// ```
+struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
+  using OpRewritePattern<PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes;
+    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
+                                 .getShape()
+                                 .take_back(op.getMixedTiles().size()),
+                             op.getMixedTiles())) {
+      int64_t shape = std::get<0>(it);
+      if (shape == ShapedType::kDynamic) {
+        newMixedTileSizes.push_back(std::get<1>(it));
+        continue;
+      }
+
+      if (Attribute attr =
+              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
+        // Already a constant
+        newMixedTileSizes.push_back(std::get<1>(it));
+      } else {
+        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
+        assert(tileSize == shape && "tile size and dim size don't match!");
+        newMixedTileSizes.push_back(
+            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+      }
+    }
+
+    // Clone op.
+    PackOp newOp = rewriter.create<PackOp>(
+        op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
 
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
@@ -4722,42 +4827,17 @@ struct FoldTensorCastProducerOp
 
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const override {
-    // InsertSliceOp has its own logic about folding tensor.cast ops.
-    if (isa<InsertSliceOp>(op.getOperation()))
-      return failure();
 
-    // Exclude DPS ops that are also LoopLike from this interface as they
-    // might need special handling of attached regions.
-    if (isa<LoopLikeOpInterface>(op.getOperation()))
+    // Reject tensor::PackOp - there's a dedicated pattern for that instead.
+    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
       return failure();
 
-    // If no operand comes from a tensor::CastOp and can be folded then fail.
-    bool hasTensorCastOperand =
-        llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
-          if (llvm::isa<BlockArgument>(opOperand.get()))
-            return false;
-          auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-          return castOp && canFoldIntoConsumerOp(castOp);
-        });
-    if (!hasTensorCastOperand)
-      return failure();
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
-    SmallVector<Type, 4> newResultTypes(op->getResultTypes());
-    SmallVector<Value, 4> newOperands;
-    newOperands.reserve(op->getNumOperands());
-    // Assumes that the result has dpsInits followed by nonDpsInits.
-    int64_t dpsInitIdx = 0;
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
-      bool fold = canFoldIntoConsumerOp(tensorCastOp);
-      newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
-      if (op.isDpsInit(&opOperand) &&
-          !llvm::isa<MemRefType>(newOperands.back().getType()))
-        newResultTypes[dpsInitIdx++] = newOperands.back().getType();
-    }
+    // Clone op
+    auto newOp = clone(rewriter, op, newResultTypes, newOperands);
 
-    // Clone op.
-    Operation *newOp = clone(rewriter, op, newResultTypes, newOperands);
     SmallVector<Value, 4> replacements;
     replacements.reserve(newOp->getNumResults());
     for (auto [oldResult, newResult] :
@@ -4781,6 +4861,7 @@ struct FoldTensorCastProducerOp
 
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
+  results.add<FoldTensorCastPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
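For context, here is a minimal sketch of the IR rewrite that the new `FoldTensorCastPackOp` pattern performs. The value names and shapes (`%src`, `%dest`, `16x16`, tiles `[8, 1]`) are illustrative assumptions, not taken from this patch's tests: casts that merely relax the pack's source and init operands to dynamic shapes are folded away, the result type is recomputed from the now-static init, and a `tensor.cast` back to the original result type is inserted for existing uses.

```mlir
// Before canonicalization: both operands of tensor.pack are "relaxed" to
// dynamic shapes through tensor.cast ops.
%src_dyn  = tensor.cast %src  : tensor<16x16xf32>    to tensor<?x?xf32>
%dest_dyn = tensor.cast %dest : tensor<2x16x8x1xf32> to tensor<?x?x8x1xf32>
%packed = tensor.pack %src_dyn inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest_dyn : tensor<?x?xf32> -> tensor<?x?x8x1xf32>

// After FoldTensorCastPackOp: the casts are folded into the pack, the result
// type becomes static, and a cast restores the original (dynamic) type for
// any remaining uses of the old result.
%packed_static = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<16x16xf32> -> tensor<2x16x8x1xf32>
%packed = tensor.cast %packed_static
    : tensor<2x16x8x1xf32> to tensor<?x?x8x1xf32>
```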