Skip to content

Commit 5558504

Browse files
[mlir][IR] Make OpOperand comparable (#70410)
Two `OpOperand`s are the same if they belong to the same owner and have the same operand number. There are currently no comparison operators defined on `OpOperand` and we work around this in multiple places by comparing pointers. Note: `OpOperand`s are stored in an op, so it is valid to compare their pointers to determine if they are the same operand. E.g., `getOperandNumber` is also implemented via pointer arithmetics.
1 parent 8d30e80 commit 5558504

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

mlir/include/mlir/IR/UseDefLists.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,16 @@ class IROperand : public detail::IROperandBase {
146146
return *this;
147147
}
148148

149+
/// Two operands are equal if they have the same owner and the same operand
150+
/// number. They are stored inside of ops, so it is valid to compare their
151+
/// pointers to determine equality.
152+
bool operator==(const IROperand<DerivedT, IRValueT> &other) const {
153+
return this == &other;
154+
}
155+
bool operator!=(const IROperand<DerivedT, IRValueT> &other) const {
156+
return !(*this == other);
157+
}
158+
149159
/// Return the current value being used by this operand.
150160
IRValueT get() const { return value; }
151161

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -537,12 +537,12 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
537537

538538
bool MaterializeInDestinationOp::bufferizesToMemoryRead(
539539
OpOperand &opOperand, const AnalysisState &state) {
540-
return &opOperand == &getSourceMutable();
540+
return opOperand == getSourceMutable();
541541
}
542542

543543
bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
544544
OpOperand &opOperand, const AnalysisState &state) {
545-
if (&opOperand == &getDestMutable()) {
545+
if (opOperand == getDestMutable()) {
546546
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
547547
return true;
548548
}
@@ -560,7 +560,7 @@ bool MaterializeInDestinationOp::mustBufferizeInPlace(
560560
AliasingValueList
561561
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
562562
const AnalysisState &state) {
563-
if (&opOperand == &getDestMutable()) {
563+
if (opOperand == getDestMutable()) {
564564
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
565565
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
566566
}

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
644644
RankedTensorType destType = insertSliceOp.getDestType();
645645

646646
// The source is always read.
647-
if (&opOperand == &insertSliceOp.getSourceMutable())
647+
if (opOperand == insertSliceOp.getSourceMutable())
648648
return true;
649649

650650
// For the destination, it depends...
651-
assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
651+
assert(opOperand == insertSliceOp.getDestMutable() && "expected dest");
652652

653653
// Dest is not read if it is entirely overwritten. E.g.:
654654
// tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -849,7 +849,7 @@ struct ReshapeOpInterface
849849
const AnalysisState &state) const {
850850
// Depending on the layout map, the source buffer may have to be copied.
851851
auto reshapeOp = cast<tensor::ReshapeOp>(op);
852-
return &opOperand == &reshapeOp.getShapeMutable();
852+
return opOperand == reshapeOp.getShapeMutable();
853853
}
854854

855855
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -931,7 +931,7 @@ struct ParallelInsertSliceOpInterface
931931
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
932932
const AnalysisState &state) const {
933933
auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
934-
return &opOperand == &parallelInsertSliceOp.getDestMutable();
934+
return opOperand == parallelInsertSliceOp.getDestMutable();
935935
}
936936

937937
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

0 commit comments

Comments
 (0)