
[mlir][tensor] Add shape inference support for tensor.concat op. #140168


Merged

Changes from all commits
108 changes: 105 additions & 3 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {

/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but
/// producer being from different dialects. Returns true when all conditions are
/// met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
    return success();
  }
};

/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except the
/// concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may safely be refined to that same static
/// shape.
///
/// Example:
///
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
/// tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
/// %2 = tensor.concat dim(0) %0, %cast :
/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto operandTensorTypes =
        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
          return llvm::cast<RankedTensorType>(type);
        });
Review comment on lines +803 to +806 (Member): No need to create this extra stack var. Just use it on the fly inside the for loop.

    int64_t dim = concatOp.getDim();
    ArrayRef<int64_t> inferredResultShape =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();

    // Find operands for which a more static shape can be inferred.
    LogicalResult matched = failure();
    for (auto [operandIdx, operandType] :
         llvm::enumerate(operandTensorTypes)) {
      // Compute inferred type for operand.
Review comment on lines +814 to +815 (Member): Use `for (auto [operandIdx, operandType] : llvm::enumerate(concatOp->getOperandTypes()))`; then you can just add the cast below where it's used.
      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
Review comment (Member): This can be hoisted outside the loop.
      inferredOperandShape[dim] = operandType.getDimSize(dim);
      auto inferredOperandType = RankedTensorType::get(
          inferredOperandShape, operandType.getElementType());

      // Check if inferred type is more static.
      if (!preservesStaticInformation(inferredOperandType, operandType)) {
        matched = success();

        // Use refined operand type and create cast from original operand.
        auto castOp =
            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
                                    concatOp.getOperand(operandIdx));
        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
          concatOp->setOperand(operandIdx, castOp->getResult(0));
        });
      }
    }

    return matched;
  }
};
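Taken together, the two review suggestions above would reshape the loop roughly as follows; this is a sketch of the reviewers' intent, not code from this diff:

```cpp
// Sketch with both suggestions applied: the mapped-range variable is gone and
// the scratch shape vector is hoisted out of the loop (only the concat
// dimension changes per operand).
int64_t dim = concatOp.getDim();
RankedTensorType inferredResultType =
    ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
for (auto [operandIdx, operandType] :
     llvm::enumerate(concatOp->getOperandTypes())) {
  // Cast to RankedTensorType at the point of use instead of materializing a
  // mapped range up front.
  auto rankedType = llvm::cast<RankedTensorType>(operandType);
  inferredOperandShape[dim] = rankedType.getDimSize(dim);
  auto inferredOperandType = RankedTensorType::get(
      inferredOperandShape, rankedType.getElementType());
  // ... same static-information check and cast insertion as in the pattern ...
}
```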

/// Ensure `tensor.concat`'s result type is at least as static as can be
/// inferred from its operand types.
///
/// Example:
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
///     -> tensor<?x?xi32>
/// ```
/// ->
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>)
///     -> tensor<?x12xi32>
/// %cast = tensor.cast %2 : tensor<?x12xi32> to tensor<?x?xi32>
/// ```
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
  using OpRewritePattern<ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    int64_t dim = concatOp.getDim();
    RankedTensorType inferredResultType =
        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());

    // The result type should be at least as static as inferred result type.
    if (preservesStaticInformation(inferredResultType,
                                   concatOp.getResultType())) {
      return failure();
    }

    auto newConcatOp = rewriter.create<ConcatOp>(
        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
                                        newConcatOp);

    return success();
  }
};
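A note on the helper both patterns rely on: `preservesStaticInformation(source, target)` returns true when `target` retains all static shape information present in `source`. With hypothetical shapes for illustration:

```cpp
// preservesStaticInformation(source, target), illustrated:
//   source = tensor<?x12xi32>, target = tensor<?x?xi32>  -> false (drops the 12)
//   source = tensor<?x?xi32>,  target = tensor<?x12xi32> -> true  (only refines)
// InferConcatOperandTypes rewrites when an operand type does NOT preserve the
// inferred operand type (the operand can be refined via tensor.cast), while
// InferConcatResultType rewrites when the declared result type does not
// preserve the inferred result type (the result is refined and cast back).
```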
} // namespace

void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results
      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
          context);
}
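For completeness, these patterns fire under the normal canonicalizer; below is a sketch of driving them directly from a pass or test. The driver entry point `applyPatternsGreedily` and the include path are assumptions based on current MLIR (older trees spell it `applyPatternsAndFoldGreedily`):

```cpp
// Sketch: run only tensor.concat's canonicalization patterns on `root`.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runConcatCanonicalizations(Operation *root) {
  MLIRContext *context = root->getContext();
  RewritePatternSet patterns(context);
  tensor::ConcatOp::getCanonicalizationPatterns(patterns, context);
  // Applies the three patterns registered above until fixpoint.
  return applyPatternsGreedily(root, std::move(patterns));
}
```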

//===----------------------------------------------------------------------===//
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -136,6 +136,32 @@ func.func @fold_concat(%arg0: tensor<1x2x?xi32>) -> (tensor<1x2x3xi32>, tensor<1

// -----

// CHECK-LABEL: infer_concat_operand_types
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x12xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x?xi32>
func.func @infer_concat_operand_types(%arg0: tensor<?x12xi32>, %arg1: tensor<?x?xi32>) -> (tensor<?x12xi32>) {
// CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?x?xi32> to tensor<?x12xi32>
%0 = tensor.concat dim(0) %arg0, %arg1: (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
// CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[CAST]] : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
return %0 : tensor<?x12xi32>
// CHECK-NEXT: return %[[CONCAT]] : tensor<?x12xi32>
}

// -----

// CHECK-LABEL: infer_concat_return_type
// CHECK-SAME: %[[ARG0:.+]]: tensor<5x12xi32>
// CHECK-SAME: %[[ARG1:.+]]: tensor<?x12xi32>
func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12xi32>) -> (tensor<?x?xi32>) {
%0 = tensor.concat dim(0) %arg0, %arg1: (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
// CHECK-NEXT: %[[CONCAT:.+]] = tensor.concat dim(0) %[[ARG0]], %[[ARG1]] : (tensor<5x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
// CHECK-NEXT: %[[CAST:.+]] = tensor.cast %[[CONCAT]] : tensor<?x12xi32> to tensor<?x?xi32>
return %0 : tensor<?x?xi32>
// CHECK-NEXT: return %[[CAST]] : tensor<?x?xi32>
}

// -----

// CHECK-LABEL: func @fold_extract
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
%const_0 = arith.constant 0 : index