Skip to content

Commit a0a55df

Browse files
[mlir][tensor][NFC] Code cleanup around shape inference support for tensor.concat op (#140616)
Addresses some code review on #140168 that came in after merge.
1 parent df0358f commit a0a55df

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
800800

801801
LogicalResult matchAndRewrite(ConcatOp concatOp,
802802
PatternRewriter &rewriter) const override {
803-
auto operandTensorTypes =
804-
llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
805-
return llvm::cast<RankedTensorType>(type);
806-
});
807-
808803
int64_t dim = concatOp.getDim();
809-
ArrayRef<int64_t> inferredResultShape =
810-
ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
804+
RankedTensorType inferredResultType =
805+
ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
811806

812807
// Find operands for which a more static shape can be inferred.
813808
LogicalResult matched = failure();
814-
for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
809+
// Inferred operand shapes are identical in every dimension except the
810+
// concatenation dimension.
811+
SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
812+
for (auto [operandIdx, operandType] :
813+
llvm::enumerate(concatOp->getOperandTypes())) {
815814
// Compute inferred type for operand.
816-
SmallVector<int64_t> inferredOperandShape(inferredResultShape);
817-
inferredOperandShape[dim] = operandType.getDimSize(dim);
815+
inferredOperandShape[dim] =
816+
cast<RankedTensorType>(operandType).getDimSize(dim);
818817
auto inferredOperandType = RankedTensorType::get(
819-
inferredOperandShape, operandType.getElementType());
818+
inferredOperandShape, inferredResultType.getElementType());
820819

821820
// Check if inferred type is more static.
822821
if (!preservesStaticInformation(inferredOperandType, operandType)) {

0 commit comments

Comments
 (0)