@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
/// 1. source and result are ranked tensors with the same element type and rank.
/// 2. the result type has more static information than the source.
///
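For intuition, a minimal sketch of a cast that meets both conditions, using a hypothetical `test.producer` op: source and result have the same rank and element type, and the result is strictly more static than the source.

```mlir
// %1 has a fully dynamic rank-2 shape; the cast result has the same rank and
// element type but a fully static shape, so both conditions hold and the cast
// can fold into the (hypothetical) producer.
%1 = "test.producer"() : () -> tensor<?x?xf32>
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
```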
@@ -773,11 +774,118 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
    return success();
  }
};
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   // Second operand dim 1 has dynamic shape constrained by dim 1 of first
+///   // operand.
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///         tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///         (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute the inferred type for this operand: the inferred result shape,
+      // but keeping the operand's own size along the concatenation dimension.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if the inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        refinedTypes.push_back({operandIdx, inferredOperandType});
+      }
+    }
+
+    if (refinedTypes.empty()) {
+      return failure();
+    }
+
+    // Use the refined types for the operands, inserting casts from the
+    // original types.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///         tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %concat = tensor.concat dim(0) %0, %1 :
+///         (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+///   %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
}

//===----------------------------------------------------------------------===//
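Taken together, the two new patterns compose: operand shapes are refined first, then the result type. Below is a minimal sketch of the expected end-to-end canonicalization; the value names and shapes are hypothetical and chosen for illustration.

```mlir
// %b and the result are less static than what %a's type implies.
%r = tensor.concat dim(0) %a, %b
    : (tensor<2x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
```
->
```mlir
// InferConcatOperandTypes refines dim 1 of %b to 12; InferConcatResultType
// then infers the concat result as tensor<?x12xi32> and a final cast
// restores the original result type.
%b_cast = tensor.cast %b : tensor<?x?xi32> to tensor<?x12xi32>
%concat = tensor.concat dim(0) %a, %b_cast
    : (tensor<2x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
%r = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
```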