Commit db8aaff

tensor.concat cast propagation
Adds canonicalization patterns which propagate inferred static shapes to `tensor.concat` operands and result types. Static shape information is propagated to other canonicalization patterns through the inserted casts.
1 parent 6d7b5c3
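For illustration only (this example is my addition, not part of the commit): a hypothetical input on which both new patterns fire under `mlir-opt --canonicalize`; the function and value names are invented.

```mlir
// Before: the concat declares less static shape information than its operands imply.
func.func @example(%a: tensor<3x?xi32>, %b: tensor<?x12xi32>) -> tensor<?x?xi32> {
  %0 = tensor.concat dim(0) %a, %b : (tensor<3x?xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
  return %0 : tensor<?x?xi32>
}

// After: %a is refined to tensor<3x12xi32> (dim 1 learned from %b), the concat
// now yields the inferred tensor<?x12xi32>, and a final cast restores the
// original tensor<?x?xi32> type for any uses that still expect it.
func.func @example(%a: tensor<3x?xi32>, %b: tensor<?x12xi32>) -> tensor<?x?xi32> {
  %cast_a = tensor.cast %a : tensor<3x?xi32> to tensor<3x12xi32>
  %concat = tensor.concat dim(0) %cast_a, %b : (tensor<3x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
  %cast = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
  return %cast : tensor<?x?xi32>
}
```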

File tree

2 files changed: +135 -3 lines changed

  mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+109 -3)
  mlir/test/Dialect/Tensor/canonicalize.mlir (+26 -0)

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 109 additions & 3 deletions
@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result are ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
@@ -773,11 +774,116 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///        tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///        (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        refinedTypes.push_back({operandIdx, inferredOperandType});
+      }
+    }
+
+    if (refinedTypes.empty()) {
+      return failure();
+    }
+
+    // Use refined types for operands; cast each original operand to its
+    // refined type.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///        tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
+///        -> tensor<?x12xi32>
+///   %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result
+    // type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
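Both patterns key off `ConcatOp::inferResultType`, which combines the operand shapes into the most static result shape they jointly imply. A sketch of the assumed behavior on concrete types (illustrative values, not taken from the commit):

```mlir
// Along the concat dim (0), static sizes add: 3 + 5 = 8; a '?' in either
// operand would make the sum '?'. Along every other dim, a static size from
// any operand refines a '?': here '?' and 12 combine to 12.
%r = tensor.concat dim(0) %x, %y : (tensor<3x?xi32>, tensor<5x12xi32>) -> tensor<8x12xi32>
```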

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 
 // -----
 
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  return %0 : tensor<?x12xi32>
+  // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+  // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
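As a usage sketch (my addition, not part of the commit): a minimal standalone test exercising both patterns at once, assuming the usual `mlir-opt --split-input-file --canonicalize ... | FileCheck` harness used for such files; the test name is invented.

```mlir
// RUN: mlir-opt --split-input-file --canonicalize %s | FileCheck %s

// CHECK-LABEL: infer_operands_and_result
func.func @infer_operands_and_result(%arg0: tensor<5x12xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
  // CHECK: %[[CAST:.+]] = tensor.cast %{{.+}} : tensor<?x?xi32> to tensor<?x12xi32>
  // CHECK: %[[CONCAT:.+]] = tensor.concat dim(0) %{{.+}}, %[[CAST]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
  // CHECK: %[[RES:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
  // CHECK: return %[[RES]]
  %0 = tensor.concat dim(0) %arg0, %arg1 : (tensor<5x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
  return %0 : tensor<?x?xi32>
}
```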
