@@ -800,23 +800,22 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    auto operandTensorTypes =
-        llvm::map_range(concatOp->getOperandTypes(), [](Type type) {
-          return llvm::cast<RankedTensorType>(type);
-        });
-
     int64_t dim = concatOp.getDim();
-    ArrayRef<int64_t> inferredResultShape =
-        ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
+    RankedTensorType inferredResultType =
+        ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
 
     // Find operands for which a more static shape can be inferred.
     LogicalResult matched = failure();
-    for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
+    // Inferred operand shapes are identical in every dimension except the
+    // concatenation dimension.
+    SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
+    for (auto [operandIdx, operandType] :
+         llvm::enumerate(concatOp->getOperandTypes())) {
       // Compute inferred type for operand.
-      SmallVector<int64_t> inferredOperandShape(inferredResultShape);
-      inferredOperandShape[dim] = operandType.getDimSize(dim);
+      inferredOperandShape[dim] =
+          cast<RankedTensorType>(operandType).getDimSize(dim);
 
       auto inferredOperandType = RankedTensorType::get(
-          inferredOperandShape, operandType.getElementType());
+          inferredOperandShape, inferredResultType.getElementType());
 
       // Check if inferred type is more static.
       if (!preservesStaticInformation(inferredOperandType, operandType)) {
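
For context (not part of the patch): a minimal standalone sketch of the shape reasoning the new code relies on, assuming tensor.concat semantics in which every non-concatenation dimension agrees across operands and the concatenation dimension of the result is the sum of the operand sizes (dynamic if any operand size is dynamic). The names inferResultShape and kDynamic are hypothetical stand-ins, with dynamic sizes modeled as -1; the sketch only illustrates why an operand's inferred shape can be derived from the inferred result shape by replacing the concatenation dimension with the operand's own size, as the loop in the patch does.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for a dynamic dimension size

// Combine operand shapes into the most static result shape: the concat
// dimension is the sum of operand sizes (dynamic if any is dynamic); every
// other dimension takes a static size if any operand provides one.
std::vector<int64_t>
inferResultShape(int64_t dim,
                 const std::vector<std::vector<int64_t>> &operands) {
  std::vector<int64_t> result = operands.front();
  result[dim] = 0;
  for (const std::vector<int64_t> &shape : operands) {
    for (std::size_t d = 0; d < shape.size(); ++d) {
      if (d == static_cast<std::size_t>(dim)) {
        // Concat dimension: accumulate, becoming dynamic on any dynamic size.
        result[d] = (result[d] == kDynamic || shape[d] == kDynamic)
                        ? kDynamic
                        : result[d] + shape[d];
      } else if (result[d] == kDynamic) {
        // Non-concat dimension: prefer a static size if one is available.
        result[d] = shape[d];
      }
    }
  }
  return result;
}

int main() {
  // Concatenate tensor<?x4xf32> and tensor<2x?xf32> along dim 0.
  std::vector<std::vector<int64_t>> operands = {{kDynamic, 4}, {2, kDynamic}};
  std::vector<int64_t> result = inferResultShape(/*dim=*/0, operands);
  assert(result[0] == kDynamic && result[1] == 4);

  // As in the patch: an operand's inferred shape equals the inferred result
  // shape in every dimension except the concat dimension, which keeps the
  // operand's own size.
  std::vector<int64_t> inferredOperand1 = result;
  inferredOperand1[0] = operands[1][0];
  assert(inferredOperand1[0] == 2 && inferredOperand1[1] == 4);
  return 0;
}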