-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ODS] Change get...Mutable
to return OpOperand &
for single operands
#66519
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ODS] Change get...Mutable
to return OpOperand &
for single operands
#66519
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tensor ChangesThe TableGen code generator now generates C++ code that returns a single E.g.:
Depends on #66515. Review only the top commit. Full diff: https://github.com/llvm/llvm-project/pull/66519.diff 11 Files Affected:
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 51d4e366e4970d5..4e550fe3e3a60e6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -268,6 +268,9 @@ class OpOperand : public IROperand<OpOperand, Value> {
/// Return which operand this is in the OpOperand list of the Operation.
unsigned getOperandNumber();
+ /// Set the current value being used by this operand.
+ void assign(Value value) { set(value); }
+
private:
/// Keep the constructor private and accessible to the OperandStorage class
/// only to avoid hard-to-debug typo/programming mistakes.
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 187185b47b66695..4546f0fe4bf48c5 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -126,6 +126,9 @@ class MutableOperandRange {
ArrayRef<OperandSegment> operandSegments = std::nullopt);
MutableOperandRange(Operation *owner);
+ /// Construct a new mutable range for the given OpOperand.
+ MutableOperandRange(OpOperand &opOperand);
+
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
slice(unsigned subStart, unsigned subLen,
@@ -162,10 +165,8 @@ class MutableOperandRange {
/// elements attribute, which contains the sizes of the sub ranges.
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
- /// Returns the value at the given index.
- Value operator[](unsigned index) const {
- return operator OperandRange()[index];
- }
+ /// Returns the OpOperand at the given index.
+ OpOperand &operator[](unsigned index) const;
OperandRange::iterator begin() const {
return static_cast<OperandRange>(*this).begin();
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 006aedced839f99..7f6967f11444f31 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -76,7 +76,7 @@ class SuccessorOperands {
Value operator[](unsigned index) const {
if (isOperandProduced(index))
return Value();
- return forwardedOperands[index - producedOperandCount];
+ return forwardedOperands[index - producedOperandCount].get();
}
/// Get the range of operands that are simply forwarded to the successor.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e5016c956804688..59ec8ccc0806f6c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -549,22 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
OpOperand &opOperand, const AnalysisState &state) {
- if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
- return true;
- return false;
+ return &opOperand == &getSourceMutable();
}
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
- if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
- return true;
- return false;
+ return &opOperand == &getDestMutable();
}
AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
- if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+ if (&opOperand == &getDestMutable())
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
return {};
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 581e7b0a8ea86a7..f704a5235571183 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
reshapeOp, "failed preconditions of fusion with producer generic op");
}
- if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+ if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
return rewriter.notifyMatchFailure(reshapeOp,
"fusion blocked by control function");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..6931d386261967d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// 1. Get the producer of the source (potentially walking through
// `iter_args` of nested `scf.for`)
auto [fusableProducer, destinationIterArg] =
- getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+ getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
loops);
if (!fusableProducer)
return std::nullopt;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ecca4dd3394e0ae..ec7a06fd8891710 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
RankedTensorType destType = insertSliceOp.getDestType();
// The source is always read.
- if (&opOperand == &op->getOpOperand(0) /*src*/)
+ if (&opOperand == &insertSliceOp.getSourceMutable())
return true;
// For the destination, it depends...
- assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+ assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
// Dest is not read if it is entirely overwritten. E.g.:
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -851,9 +851,8 @@ struct ReshapeOpInterface
tensor::ReshapeOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- if (&opOperand == &op->getOpOperand(1) /* shape */)
- return true;
- return false;
+ auto reshapeOp = cast<tensor::ReshapeOp>(op);
+ return &opOperand == &reshapeOp.getShapeMutable();
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -915,7 +914,8 @@ struct ParallelInsertSliceOpInterface
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
- return &opOperand == &op->getOpOperand(1) /*dest*/;
+ auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+ return &opOperand == ¶llelInsertSliceOp.getDestMutable();
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 0cb6a1cd191b161..a9b55cec7659c55 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange(
MutableOperandRange::MutableOperandRange(Operation *owner)
: MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
+/// Construct a new mutable range for the given OpOperand.
+MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
+ : MutableOperandRange(opOperand.getOwner(),
+ /*start=*/opOperand.getOperandNumber(),
+ /*length=*/1) {}
+
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
MutableOperandRange::slice(unsigned subStart, unsigned subLen,
@@ -517,6 +523,10 @@ void MutableOperandRange::updateLength(unsigned newLength) {
}
}
+OpOperand &MutableOperandRange::operator[](unsigned index) const {
+ return owner->getOpOperand(start + index);
+}
+
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 84f23584e9f30e3..9aab89ed7553600 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -277,7 +277,8 @@ class EdgeMultiplexer {
if (index >= result->second &&
index < result->second + edge.getSuccessor()->getNumArguments()) {
// Original block arguments to the entry block.
- newSuccOperands[index] = successorOperands[index - result->second];
+ newSuccOperands[index] =
+ successorOperands[index - result->second].get();
continue;
}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 00c251936655d71..e3d86b4a44d0001 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions(
OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
assert(point == getBody());
- return getInitMutable();
+ return MutableOperandRange(getInitMutable());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ad4f53c5af3cff4..df1d13d3bf5580d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2071,29 +2071,36 @@ void OpEmitter::genNamedOperandSetters() {
continue;
std::string name = op.getGetterName(operand.name);
- auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
- ? "::mlir::MutableOperandRangeRange"
- : "::mlir::MutableOperandRange",
- name + "Mutable");
+ StringRef returnType;
+ if (operand.isVariadicOfVariadic()) {
+ returnType = "::mlir::MutableOperandRangeRange";
+ } else if (operand.isVariableLength()) {
+ returnType = "::mlir::MutableOperandRange";
+ } else {
+ returnType = "::mlir::OpOperand &";
+ }
+ auto *m = opClass.addMethod(returnType, name + "Mutable");
ERROR_IF_PRUNED(m, name, op);
auto &body = m->body();
- body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
- << " auto mutableRange = "
- "::mlir::MutableOperandRange(getOperation(), "
- "range.first, range.second";
- if (attrSizedOperands) {
- if (emitHelper.hasProperties())
- body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
- "{{getOperandSegmentSizesAttrName(), "
- "::mlir::DenseI32ArrayAttr::get(getContext(), "
- "getProperties().operandSegmentSizes)})",
- i);
- else
- body << formatv(
- ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
- emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ body << " auto range = getODSOperandIndexAndLength(" << i << ");\n";
+ if (operand.isVariadicOfVariadic() || operand.isVariableLength()) {
+ body << " auto mutableRange = "
+ "::mlir::MutableOperandRange(getOperation(), "
+ "range.first, range.second";
+ if (attrSizedOperands) {
+ if (emitHelper.hasProperties())
+ body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+ "{{getOperandSegmentSizesAttrName(), "
+ "::mlir::DenseI32ArrayAttr::get(getContext(), "
+ "getProperties().operandSegmentSizes)})",
+ i);
+ else
+ body << formatv(
+ ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+ emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+ }
+ body << ");\n";
}
- body << ");\n";
// If this operand is a nested variadic, we split the range into a
// MutableOperandRangeRange that provides a range over all of the
@@ -2104,9 +2111,13 @@ void OpEmitter::genNamedOperandSetters() {
<< op.getGetterName(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
<< "AttrName()));\n";
- } else {
- // Otherwise, we use the full range directly.
+ } else if (operand.isVariableLength()) {
+ // Otherwise, if the operand has variable length, we use the full range
+ // directly.
body << " return mutableRange;\n";
+ } else {
+ // In case of a single operand, return a single OpOperand.
+ body << " return getOperation()->getOpOperand(range.first);\n";
}
}
}
|
get...Mutable
returns OpOperand &
for single operands
24f3431
to
04dc3b3
Compare
get...Mutable
returns OpOperand &
for single operandsget...Mutable
returns OpOperand &
for single operands
04dc3b3
to
ee804c9
Compare
get...Mutable
returns OpOperand &
for single operandsget...Mutable
to return OpOperand &
for single operands
The TableGen code generator now generates C++ code that returns a single `OpOperand &` for `get...Mutable` of operands that are not variadic and not optional. `OpOperand::set`/`assign` can be used to set a value (same as `MutableOperandRange::assign`). It is safer than `MutableOperandRange` only single values (and no longer `ValueRange`) can be assigned. E.g.: ``` // Before: Assign multiple values to non-variadic operand (forbidden, but // compiles). // After: Compilation error. extractSliceOp.getSourceMutable().assign({v1, v2}); ``` BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC
ee804c9
to
7e5e137
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good in general, looks like it'll avoid a few more footguns. Thanks!
The TableGen code generator now generates C++ code that returns a single
OpOperand &
forget...Mutable
of operands that are not variadic and not optional.OpOperand::set
/assign
can be used to set a value (same asMutableOperandRange::assign
). This is safer thanMutableOperandRange
because only single values (and no longerValueRange
) can be assigned.E.g.: