@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
   return newOperands;
 }
 
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+//  * a dim from newPackedTy is static, and
+//  * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+                     SmallVector<OpFoldResult> mixedTiles) {
+  SmallVector<OpFoldResult> newMixedTileSizes;
+  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+                               .getShape()
+                               .take_back(mixedTiles.size()),
+                           mixedTiles)) {
+    int64_t shape = std::get<0>(it);
+    if (shape == ShapedType::kDynamic) {
+      newMixedTileSizes.push_back(std::get<1>(it));
+      continue;
+    }
+
+    // If the current result dim is static, update the dynamic mixed-size
+    // (provided the original value is dynamic).
+    OpFoldResult tile = std::get<1>(it);
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+      // Already a constant
+      newMixedTileSizes.push_back(tile);
+    } else {
+      assert(getConstantIntValue(tile).value() == shape &&
+             "tile size and dim size don't match!");
+      newMixedTileSizes.push_back(
+          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+    }
+  }
+
+  return newMixedTileSizes;
+}
+
 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
 /// `tensor.cast` has source that is more static than the consuming op.
 ///
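(Aside, not part of the patch: the sketch below illustrates what the new helper buys the `FoldTensorCastPackOp` pattern whose doc comment starts above. The SSA names and shapes are invented for this note, but they follow the rule stated in the helper's comment: a dynamic tile size is replaced only when the corresponding dim of the packed type has become static.)

```mlir
// Hypothetical input: the pack destination goes through a tensor.cast that
// hides a static 8, while the inner tile %c8 is still a dynamic value.
%c8 = arith.constant 8 : index
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%pack = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [%c8, 1]
    into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>

// After the fold: the cast is folded into the pack, the packed type becomes
// static, and getNewMixedTileSizes rewrites the dynamic tile %c8 to the
// constant 8 (dims 8x1 zipped against tiles [%c8, 1]). Since the result type
// changed, the pattern casts the new result back to the original type.
%pack_folded = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
%replacement = tensor.cast %pack_folded
    : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
```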
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
-                                 .getShape()
-                                 .take_back(op.getMixedTiles().size()),
-                             op.getMixedTiles())) {
-      int64_t shape = std::get<0>(it);
-      if (shape == ShapedType::kDynamic) {
-        newMixedTileSizes.push_back(std::get<1>(it));
-        continue;
-      }
-
-      if (Attribute attr =
-              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-        // Already a constant
-        newMixedTileSizes.push_back(std::get<1>(it));
-      } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
-        newMixedTileSizes.push_back(
-            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-      }
-    }
+    SmallVector<OpFoldResult> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
 
     // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
 
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
+    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+        rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
+
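(Aside, not part of the patch: a sketch of the new unpack-side fold with a dynamic inner tile. The names and shapes are invented for this note, modelled on the doc comment above; they are not taken from the patch or its tests.)

```mlir
// Hypothetical input: the unpack source is a cast of a fully static tensor,
// while one inner tile is still the dynamic value %c8.
%c8 = arith.constant 8 : index
%cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1] inner_tiles = [%c8, 1]
    into %dest : tensor<1x1x?x1xi32> -> tensor<7x?xi32>

// After FoldTensorCastUnPackOp: the cast is folded into the unpack and the
// dynamic tile %c8 becomes the constant 8 taken from the now-static source
// dim. The result type is unchanged, so no cast back is needed.
%unpack_folded = tensor.unpack %src
    inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```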
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {
 
     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<FoldTensorCastPackOp>(getContext());
+  results.add<FoldTensorCastUnPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
 