[mlir][tensor] Add shape inference support for tensor.concat op. #140168
Changes from all commits
@@ -33,6 +33,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <algorithm>
 #include <optional>
@@ -330,8 +331,9 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {

 /// Determines whether the tensor::CastOp casts to a more static version of the
 /// source tensor. This is useful to fold into a producing op and implement
-/// canonicaliation patterns with the `tensor.cast` op as the root, but producer
-/// being from different dialects. Returns true when all conditions are met:
+/// canonicalization patterns with the `tensor.cast` op as the root, but
+/// producer being from different dialects. Returns true when all conditions are
+/// met:
 /// 1. source and result and ranked tensors with same element type and rank.
 /// 2. the result type has more static information than the source.
 ///
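For reference (not part of the diff), a cast meeting the two conditions described in that comment would look like the following sketch; `%src` and `%dst` are invented names:

```mlir
// Source and result are ranked tensors of the same rank and element type,
// and the result type is strictly more static than the source type, so the
// cast can be folded into the producer of %src.
%dst = tensor.cast %src : tensor<?x12xi32> to tensor<4x12xi32>
```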
@@ -773,11 +775,111 @@ struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
     return success();
   }
 };

+/// Propagate static shapes into the operands of a `tensor.concat`.
+///
+/// `tensor.concat` requires every operand to match on all dimensions except the
+/// concatenation dimension. If one operand is already static in those
+/// dimensions, the other operands may safely be refined to that same static
+/// shape.
+///
+/// Example:
+///
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
+///         tensor<?x12xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
+///   %2 = tensor.concat dim(0) %0, %cast :
+///         (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+/// ```
+struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    auto operandTensorTypes =
+        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
+          return llvm::cast<RankedTensorType>(type);
+        });
+
+    int64_t dim = concatOp.getDim();
+    ArrayRef<int64_t> inferredResultShape =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+
+    // Find operands for which a more static shape can be inferred.
+    LogicalResult matched = failure();
+    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+      // Compute inferred type for operand.
+      SmallVector<int64_t> inferredOperandShape(inferredResultShape);

Review comment on lines +814 to +815: This can be hoisted outside the loop.
+      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      auto inferredOperandType = RankedTensorType::get(
+          inferredOperandShape, operandType.getElementType());
+
+      // Check if inferred type is more static.
+      if (!preservesStaticInformation(inferredOperandType, operandType)) {
+        matched = success();
+
+        // Use refined operand type and create cast from original operand.
+        auto castOp =
+            rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
+                                    concatOp.getOperand(operandIdx));
+        rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
+          concatOp->setOperand(operandIdx, castOp->getResult(0));
+        });
+      }
+    }
+
+    return matched;
+  }
+};
+
+/// Ensure `tensor.concat`'s result type is at least as static as can be
+/// inferred from its operand types.
+///
+/// Example:
+/// ```mlir
+///   %2 = tensor.concat dim(0) %0, %1 : (tensor<?x12xi32>, tensor<?x12xi32>) ->
+///         tensor<?x?xi32>
+/// ```
+/// ->
+/// ```mlir
+///   %concat = tensor.concat dim(0) %0, %1 :
+///         (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
+///   %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
+/// ```
+struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
+  using OpRewritePattern<ConcatOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ConcatOp concatOp,
+                                PatternRewriter &rewriter) const override {
+    int64_t dim = concatOp.getDim();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
+
+    // The result type should be at least as static as the inferred result type.
+    if (preservesStaticInformation(inferredResultType,
+                                   concatOp.getResultType())) {
+      return failure();
+    }
+
+    auto newConcatOp = rewriter.create<ConcatOp>(
+        concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
+    rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
+                                        newConcatOp);
+
+    return success();
+  }
+};
 } // namespace

 void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<SingleInputConcatOp>(context);
+  results
+      .add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
+          context);
 }

 //===----------------------------------------------------------------------===//
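Stepping outside the diff for a moment: the following is a minimal IR sketch of what `InferConcatOperandTypes` is expected to rewrite under `mlir-opt --canonicalize`. It is not taken from the patch; the function and argument names are invented.

```mlir
// Before: %arg1 is dynamic in the non-concatenated dimension, even though the
// other operand fixes that dimension to 12.
func.func @refine_operands(%arg0: tensor<?x12xi32>,
                           %arg1: tensor<?x?xi32>) -> tensor<?x12xi32> {
  %0 = tensor.concat dim(0) %arg0, %arg1
      : (tensor<?x12xi32>, tensor<?x?xi32>) -> tensor<?x12xi32>
  return %0 : tensor<?x12xi32>
}

// Expected after canonicalization: the dynamic operand is refined via a cast.
//   %cast = tensor.cast %arg1 : tensor<?x?xi32> to tensor<?x12xi32>
//   %0 = tensor.concat dim(0) %arg0, %cast
//       : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
```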
Review comment: No need to create this extra stack var. Just use it on the fly inside the for loop.
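Likewise, a hypothetical sketch for `InferConcatResultType` (again, names are invented and not part of the patch): the operand types force dimension 1 of the result to be 12, so the result type can be tightened and a `tensor.cast` restores the original type for existing users.

```mlir
// Before: the declared result type is less static than the operand types imply.
func.func @refine_result(%arg0: tensor<?x12xi32>,
                         %arg1: tensor<?x12xi32>) -> tensor<?x?xi32> {
  %0 = tensor.concat dim(0) %arg0, %arg1
      : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x?xi32>
  return %0 : tensor<?x?xi32>
}

// Expected after canonicalization: the concat carries the inferred type and a
// tensor.cast converts back to the original result type.
//   %concat = tensor.concat dim(0) %arg0, %arg1
//       : (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
//   %0 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
```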