33
33
#include " llvm/ADT/STLExtras.h"
34
34
#include " llvm/ADT/SmallBitVector.h"
35
35
#include " llvm/ADT/StringRef.h"
36
+ #include " llvm/Support/LogicalResult.h"
36
37
#include " llvm/Support/MathExtras.h"
37
38
#include < algorithm>
38
39
#include < optional>
@@ -809,7 +810,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
809
810
ConcatOp::inferResultType (dim, concatOp->getOperandTypes ()).getShape ();
810
811
811
812
// Find operands for which a more static shape can be inferred.
812
- SmallVector<std::tuple< size_t , RankedTensorType>> refinedTypes ;
813
+ LogicalResult matched = failure () ;
813
814
for (auto [operandIdx, operandType] : llvm::enumerate (operandTensorTypes)) {
814
815
// Compute inferred type for operand.
815
816
SmallVector<int64_t > inferredOperandShape (inferredResultShape);
@@ -819,24 +820,20 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
819
820
820
821
// Check if inferred type is more static.
821
822
if (!preservesStaticInformation (inferredOperandType, operandType)) {
822
- refinedTypes.push_back ({operandIdx, inferredOperandType});
823
+ matched = success ();
824
+
825
+ // Use refined operand type and create cast from original operand.
826
+ auto castOp =
827
+ rewriter.create <CastOp>(concatOp->getLoc (), inferredOperandType,
828
+ concatOp.getOperand (operandIdx));
829
+ rewriter.modifyOpInPlace (
830
+ concatOp, [=, operandIdx = (size_t )operandIdx] {
831
+ concatOp->setOperand (operandIdx, castOp->getResult (0 ));
832
+ });
823
833
}
824
834
}
825
835
826
- if (refinedTypes.empty ()) {
827
- return failure ();
828
- }
829
-
830
- // Use refined types for operands, insert casts for original type.
831
- SmallVector<Value> newOperands = concatOp.getOperands ();
832
- for (auto [operandIdx, refinedType] : refinedTypes) {
833
- newOperands[operandIdx] = rewriter.create <CastOp>(
834
- concatOp->getLoc (), refinedType, concatOp.getOperand (operandIdx));
835
- }
836
- rewriter.replaceOpWithNewOp <ConcatOp>(concatOp, concatOp.getResultType (),
837
- dim, newOperands);
838
-
839
- return success ();
836
+ return matched;
840
837
}
841
838
};
842
839
0 commit comments