Commit 27233af
tensor.concat cast propagation
Adds canonicalization patterns that propagate inferred static shapes to `tensor.concat` operand and result types. The static shape information is propagated to other canonicalization patterns through `tensor.cast` ops.
1 parent 6d7b5c3 commit 27233af
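As a sketch of the net effect of the two patterns together (distilled from the doc comments and tests below; SSA names illustrative):

```mlir
// Before: the second operand and the result are typed more dynamically than
// the first operand allows.
%0 = tensor.concat dim(0) %a, %b
    : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>

// After canonicalization: the operand is refined through a cast, the concat
// gets the inferred static result type, and a final cast restores the
// original type for existing users.
%cast = tensor.cast %b : tensor<?x?xi32> to tensor<?x12xi32>
%concat = tensor.concat dim(0) %a, %cast
    : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
%0 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
```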

2 files changed: 137 additions, 3 deletions

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 111 additions & 3 deletions
@@ -330,8 +330,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
 
 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result and ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
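For illustration (not part of this commit), a `tensor.cast` that meets both conditions and can therefore fold into its producer:

```mlir
// Same rank and element type; the result type has strictly more static
// information than the source type.
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
```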
@@ -773,11 +774,118 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };
+
+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+/// // Second operand dim 1 has dynamic shape constrained by dim 1 of first
+/// // operand.
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///       tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+/// %2 = tensor.concat dim(0) %0, %cast :
+///       (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        refinedTypes.push_back({operandIdx, inferredOperandType});
+      }
+    }
+
+    if (refinedTypes.empty()) {
+      return failure();
+    }
+
+    // Use refined types for operands, insert casts for original type.
+    SmallVector<Value> newOperands = concatOp.getOperands();
+    for (auto [operandIdx, refinedType] : refinedTypes) {
+      newOperands[operandIdx] = rewriter.create<CastOp>(
+          concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
+    }
+    rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
+                                          dim, newOperands);
+
+    return success();
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///       tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+/// %concat = tensor.concat dim(0) %0, %1 :
+///       (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        concatOp.inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return llvm::success();
+  }
+};
 } // namespace
 
 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
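A minimal sketch of how the inserted casts feed the rest of canonicalization (assuming the existing `tensor.cast` chain-folding pattern also runs; names illustrative):

```mlir
// InferConcatResultType rewrites the concat to the inferred static type and
// compensates with a cast back to tensor<?x?xi32>. A pre-existing consumer
// cast to tensor<?x12xi32> then cancels against that compensating cast:
%0 = tensor.concat dim(0) %a, %b
    : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
%1 = tensor.cast %0 : tensor<?x?xi32> to tensor<?x12xi32>

// ...canonicalizes to:
%1 = tensor.concat dim(0) %a, %b
    : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
```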

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 26 additions & 0 deletions
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1
 
 // -----
 
+// CHECK-LABEL: infer_concat_operand_types
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<?x12xi32>
+// CHECK-SAME:  %[[ARG1:.+]]: tensor<?x?xi32>
+func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  return %0 : tensor<?x12xi32>
+  // CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
+}
+
+// -----
+
+// CHECK-LABEL: infer_concat_return_type
+// CHECK-SAME:  %[[ARG0:.+]]: tensor<5x12xi32>
+// CHECK-SAME:  %[[ARG1:.+]]: tensor<?x12xi32>
+func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
+  %0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
+  // CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+  // CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
+  return %0 : tensor<?x?xi32>
+  // CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_extract
 func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
   %const_0 = arith.constant 0 : index
