@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result and ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
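For illustration, the kind of IR this predicate is meant to match (a sketch; `"some.producer"` is a hypothetical stand-in for a producing op from another dialect):

```mlir
%0 = "some.producer"() : () -> tensor<?x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<8x16xf32>
// All conditions hold: same rank and element type, and tensor<8x16xf32> has
// strictly more static information, so a canonicalization rooted at the cast
// may fold it into the producer.
```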
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///   tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        refinedTypes.push_back({operandIdx, inferredOperandType});
+      }
+    }
+
+    if (refinedTypes.empty()) {
+      return failure();
+    }
+
+    // Use refined types for operands, insert casts for original type.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///   tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %concat = tensor.concat dim(0) %0, %1 :
+///   (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+///   %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
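Taken together, the two new patterns compose: `InferConcatOperandTypes` refines the operands, and `InferConcatResultType` then refines the result. A sketch of the combined effect on IR (value names are illustrative):

```mlir
// Before canonicalization:
%r = tensor.concat dim(0) %a, %b
    : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>

// After both patterns have applied:
%b_cast = tensor.cast %b : tensor<?x?xi32> to tensor<?x12xi32>
%c = tensor.concat dim(0) %a, %b_cast
    : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
%r = tensor.cast %c : tensor<?x12xi32> to tensor<?x?xi32>
```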
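A lit test along the following lines would exercise the new canonicalizations (a hypothetical test, not part of this commit; the function name and CHECK lines are assumptions):

```mlir
// RUN: mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: func @infer_concat_result_type
func.func @infer_concat_result_type(%arg0: tensor<?x12xi32>,
                                    %arg1: tensor<?x12xi32>) -> tensor<?x?xi32> {
  // CHECK: %[[CONCAT:.+]] = tensor.concat dim(0) {{.+}} -> tensor<?x12xi32>
  // CHECK: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
  %0 = tensor.concat dim(0) %arg0, %arg1
      : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
  // CHECK: return %[[CAST]]
  return %0 : tensor<?x?xi32>
}
```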