|
33 | 33 | #include "llvm/ADT/STLExtras.h"
|
34 | 34 | #include "llvm/ADT/SmallBitVector.h"
|
35 | 35 | #include "llvm/ADT/StringRef.h"
|
| 36 | +#include "llvm/Support/LogicalResult.h" |
36 | 37 | #include "llvm/Support/MathExtras.h"
|
37 | 38 | #include <algorithm>
|
38 | 39 | #include <optional>
|
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
330 | 331 |
|
331 | 332 | /// Determines whether the tensor::CastOp casts to a more static version of the
|
332 | 333 | /// source tensor. This is useful to fold into a producing op and implement
|
333 |
| -/// canonicaliation patterns with the `tensor.cast` op as the root, but producer |
334 |
| -/// being from different dialects. Returns true when all conditions are met: |
| 334 | +/// canonicalization patterns with the `tensor.cast` op as the root, but |
| 335 | +/// producer being from different dialects. Returns true when all conditions are |
| 336 | +/// met: |
335 | 337 | /// 1. source and result and ranked tensors with same element type and rank.
|
336 | 338 | /// 2. the result type has more static information than the source.
|
337 | 339 | ///
|
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
|
773 | 775 | return success();
|
774 | 776 | }
|
775 | 777 | };
|
| 778 | + |
| 779 | +/// Propagate static shapes into the operands of a `tensor.concat`. |
| 780 | +/// |
| 781 | +/// `tensor.concat` requires every operand to match on all dimensions except the |
| 782 | +/// concatenation dimension. If one operand is already static in those |
| 783 | +/// dimensions, the other operands may safely be refined to that same static |
| 784 | +/// shape. |
| 785 | +/// |
| 786 | +/// Example: |
| 787 | +/// |
| 788 | +/// ```mlir |
| 789 | +/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) -> |
| 790 | +/// tensor<?x12xi32> |
| 791 | +/// ``` |
| 792 | +/// -> |
| 793 | +/// ```mlir |
| 794 | +/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32> |
| 795 | +/// %2 = tensor.concat dim(0) %0, %cast : |
| 796 | +/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32> |
| 797 | +/// ``` |
| 798 | +struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> { |
| 799 | + using OpRewritePattern<ConcatOp>::OpRewritePattern; |
| 800 | + |
| 801 | + LogicalResult matchAndRewrite(ConcatOp concatOp, |
| 802 | + PatternRewriter &rewriter) const override { |
| 803 | + auto operandTensorTypes = |
| 804 | + llvm::map_range(concatOp->getOperandTypes(), [](Type type) { |
| 805 | + return llvm::cast<RankedTensorType>(type); |
| 806 | + }); |
| 807 | + |
| 808 | + int64_t dim = concatOp.getDim(); |
| 809 | + ArrayRef<int64_t> inferredResultShape = |
| 810 | + ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape(); |
| 811 | + |
| 812 | + // Find operands for which a more static shape can be inferred. |
| 813 | + LogicalResult matched = failure(); |
| 814 | + for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) { |
| 815 | + // Compute inferred type for operand. |
| 816 | + SmallVector<int64_t> inferredOperandShape(inferredResultShape); |
| 817 | + inferredOperandShape[dim] = operandType.getDimSize(dim); |
| 818 | + auto inferredOperandType = RankedTensorType::get( |
| 819 | + inferredOperandShape, operandType.getElementType()); |
| 820 | + |
| 821 | + // Check if inferred type is more static. |
| 822 | + if (!preservesStaticInformation(inferredOperandType, operandType)) { |
| 823 | + matched = success(); |
| 824 | + |
| 825 | + // Use refined operand type and create cast from original operand. |
| 826 | + auto castOp = |
| 827 | + rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType, |
| 828 | + concatOp.getOperand(operandIdx)); |
| 829 | + rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] { |
| 830 | + concatOp->setOperand(operandIdx, castOp->getResult(0)); |
| 831 | + }); |
| 832 | + } |
| 833 | + } |
| 834 | + |
| 835 | + return matched; |
| 836 | + } |
| 837 | +}; |
| 838 | + |
| 839 | +// Ensure `tensor.concat`'s result type is at least as static as can be inferred |
| 840 | +// from its operand types. |
| 841 | +/// |
| 842 | +/// Example: |
| 843 | +/// ```mlir |
| 844 | +/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) -> |
| 845 | +/// tensor<?x?xi32> |
| 846 | +/// ``` |
| 847 | +/// -> |
| 848 | +/// ```mlir |
| 849 | +/// %2 = tensor.concat dim(0) %0, %cast : (tensor<?x12xi32>, tensor<?x12xi32>) |
| 850 | +/// -> tensor<?x12xi32> %cast = tensor.cast %2 : tensor<?x12xi32> to |
| 851 | +/// tensor<?x?xi32> |
| 852 | +/// ``` |
| 853 | +struct InferConcatResultType : public OpRewritePattern<ConcatOp> { |
| 854 | + using OpRewritePattern<ConcatOp>::OpRewritePattern; |
| 855 | + |
| 856 | + LogicalResult matchAndRewrite(ConcatOp concatOp, |
| 857 | + PatternRewriter &rewriter) const override { |
| 858 | + int64_t dim = concatOp.getDim(); |
| 859 | + RankedTensorType inferredResultType = |
| 860 | + ConcatOp::inferResultType(dim, concatOp->getOperandTypes()); |
| 861 | + |
| 862 | + // The result type should be at least as static as inferred result type. |
| 863 | + if (preservesStaticInformation(inferredResultType, |
| 864 | + concatOp.getResultType())) { |
| 865 | + return failure(); |
| 866 | + } |
| 867 | + |
| 868 | + auto newConcatOp = rewriter.create<ConcatOp>( |
| 869 | + concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands()); |
| 870 | + rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(), |
| 871 | + newConcatOp); |
| 872 | + |
| 873 | + return success(); |
| 874 | + } |
| 875 | +}; |
776 | 876 | } // namespace
|
777 | 877 |
|
778 | 878 | void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
779 | 879 | MLIRContext *context) {
|
780 |
| - results.add<SingleInputConcatOp>(context); |
| 880 | + results |
| 881 | + .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>( |
| 882 | + context); |
781 | 883 | }
|
782 | 884 |
|
783 | 885 | //===----------------------------------------------------------------------===//
|
|
0 commit comments