Skip to content

Commit b0d1405

Browse files
committed
Code review comment regarding setOperand
1 parent 0bcd656 commit b0d1405

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/ADT/STLExtras.h"
3434
#include "llvm/ADT/SmallBitVector.h"
3535
#include "llvm/ADT/StringRef.h"
36+
#include "llvm/Support/LogicalResult.h"
3637
#include "llvm/Support/MathExtras.h"
3738
#include <algorithm>
3839
#include <optional>
@@ -809,7 +810,7 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
809810
ConcatOp::inferResultType(dim, concatOp->getOperandTypes()).getShape();
810811

811812
// Find operands for which a more static shape can be inferred.
812-
SmallVector<std::tuple<size_t, RankedTensorType>> refinedTypes;
813+
LogicalResult matched = failure();
813814
for (auto [operandIdx, operandType] : llvm::enumerate(operandTensorTypes)) {
814815
// Compute inferred type for operand.
815816
SmallVector<int64_t> inferredOperandShape(inferredResultShape);
@@ -819,24 +820,20 @@ struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
819820

820821
// Check if inferred type is more static.
821822
if (!preservesStaticInformation(inferredOperandType, operandType)) {
822-
refinedTypes.push_back({operandIdx, inferredOperandType});
823+
matched = success();
824+
825+
// Use refined operand type and create cast from original operand.
826+
auto castOp =
827+
rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
828+
concatOp.getOperand(operandIdx));
829+
rewriter.modifyOpInPlace(
830+
concatOp, [=, operandIdx = (size_t)operandIdx] {
831+
concatOp->setOperand(operandIdx, castOp->getResult(0));
832+
});
823833
}
824834
}
825835

826-
if (refinedTypes.empty()) {
827-
return failure();
828-
}
829-
830-
// Use refined types for operands, insert casts for original type.
831-
SmallVector<Value> newOperands = concatOp.getOperands();
832-
for (auto [operandIdx, refinedType] : refinedTypes) {
833-
newOperands[operandIdx] = rewriter.create<CastOp>(
834-
concatOp->getLoc(), refinedType, concatOp.getOperand(operandIdx));
835-
}
836-
rewriter.replaceOpWithNewOp<ConcatOp>(concatOp, concatOp.getResultType(),
837-
dim, newOperands);
838-
839-
return success();
836+
return matched;
840837
}
841838
};
842839

0 commit comments

Comments
 (0)