Skip to content

[mlir][ODS] Change get...Mutable to return OpOperand & for single operands #66519

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Sep 15, 2023

The TableGen code generator now generates C++ code that returns a single OpOperand & for get...Mutable of operands that are not variadic and not optional. OpOperand::set/assign can be used to set a value (same as MutableOperandRange::assign). This is safer than MutableOperandRange because only single values (and no longer ValueRange) can be assigned.

E.g.:

// Assignment of multiple values to non-variadic operand.
// Before: Compiles, but produces invalid op.
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});

@llvmbot
Copy link
Member

llvmbot commented Sep 15, 2023

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-tensor

Changes

The TableGen code generator now generates C++ code that returns a single OpOperand & for get...Mutable of operands that are not variadic and not optional. OpOperand::set/assign can be used to set a value (same as MutableOperandRange::assign). It is safer than MutableOperandRange because only single values (and no longer ValueRange) can be assigned.

E.g.:

// Before: Assign multiple values to non-variadic operand (forbidden, but
//         compiles).
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});

Depends on #66515. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/66519.diff

11 Files Affected:

  • (modified) mlir/include/mlir/IR/Value.h (+3)
  • (modified) mlir/include/mlir/IR/ValueRange.h (+5-4)
  • (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+1-1)
  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+3-7)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+1-1)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+6-6)
  • (modified) mlir/lib/IR/OperationSupport.cpp (+10)
  • (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+2-1)
  • (modified) mlir/test/lib/Dialect/Test/TestDialect.cpp (+1-1)
  • (modified) mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp (+33-22)
diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
index 51d4e366e4970d5..4e550fe3e3a60e6 100644
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -268,6 +268,9 @@ class OpOperand : public IROperand<OpOperand, Value> {
   /// Return which operand this is in the OpOperand list of the Operation.
   unsigned getOperandNumber();
 
+  /// Set the current value being used by this operand.
+  void assign(Value value) { set(value); }
+
 private:
   /// Keep the constructor private and accessible to the OperandStorage class
   /// only to avoid hard-to-debug typo/programming mistakes.
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index 187185b47b66695..4546f0fe4bf48c5 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -126,6 +126,9 @@ class MutableOperandRange {
                       ArrayRef<OperandSegment> operandSegments = std::nullopt);
   MutableOperandRange(Operation *owner);
 
+  /// Construct a new mutable range for the given OpOperand.
+  MutableOperandRange(OpOperand &opOperand);
+
   /// Slice this range into a sub range, with the additional operand segment.
   MutableOperandRange
   slice(unsigned subStart, unsigned subLen,
@@ -162,10 +165,8 @@ class MutableOperandRange {
   /// elements attribute, which contains the sizes of the sub ranges.
   MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
 
-  /// Returns the value at the given index.
-  Value operator[](unsigned index) const {
-    return operator OperandRange()[index];
-  }
+  /// Returns the OpOperand at the given index.
+  OpOperand &operator[](unsigned index) const;
 
   OperandRange::iterator begin() const {
     return static_cast<OperandRange>(*this).begin();
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index 006aedced839f99..7f6967f11444f31 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -76,7 +76,7 @@ class SuccessorOperands {
   Value operator[](unsigned index) const {
     if (isOperandProduced(index))
       return Value();
-    return forwardedOperands[index - producedOperandCount];
+    return forwardedOperands[index - producedOperandCount].get();
   }
 
   /// Get the range of operands that are simply forwarded to the successor.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index e5016c956804688..59ec8ccc0806f6c 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -549,22 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
 
 bool MaterializeInDestinationOp::bufferizesToMemoryRead(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(0) /*source*/)
-    return true;
-  return false;
+  return &opOperand == &getSourceMutable();
 }
 
 bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
     OpOperand &opOperand, const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
-    return true;
-  return false;
+  return &opOperand == &getDestMutable();
 }
 
 AliasingValueList
 MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
                                               const AnalysisState &state) {
-  if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/)
+  if (&opOperand == &getDestMutable())
     return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
   return {};
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 581e7b0a8ea86a7..f704a5235571183 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion
           reshapeOp, "failed preconditions of fusion with producer generic op");
     }
 
-    if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) {
+    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
       return rewriter.notifyMatchFailure(reshapeOp,
                                          "fusion blocked by control function");
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 597676a017bf482..6931d386261967d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationIterArg] =
-      getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0),
+      getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(),
                                         loops);
   if (!fusableProducer)
     return std::nullopt;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index ecca4dd3394e0ae..ec7a06fd8891710 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -644,11 +644,11 @@ struct InsertSliceOpInterface
     RankedTensorType destType = insertSliceOp.getDestType();
 
     // The source is always read.
-    if (&opOperand == &op->getOpOperand(0) /*src*/)
+    if (&opOperand == &insertSliceOp.getSourceMutable())
       return true;
 
     // For the destination, it depends...
-    assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest");
+    assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest");
 
     // Dest is not read if it is entirely overwritten. E.g.:
     // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32>
@@ -851,9 +851,8 @@ struct ReshapeOpInterface
                                                     tensor::ReshapeOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
-    if (&opOperand == &op->getOpOperand(1) /* shape */)
-      return true;
-    return false;
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    return &opOperand == &reshapeOp.getShapeMutable();
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -915,7 +914,8 @@ struct ParallelInsertSliceOpInterface
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const AnalysisState &state) const {
-    return &opOperand == &op->getOpOperand(1) /*dest*/;
+    auto parallelInsertSliceOp = cast<ParallelInsertSliceOp>(op);
+    return &opOperand == &parallelInsertSliceOp.getDestMutable();
   }
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 0cb6a1cd191b161..a9b55cec7659c55 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange(
 MutableOperandRange::MutableOperandRange(Operation *owner)
     : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {}
 
+/// Construct a new mutable range for the given OpOperand.
+MutableOperandRange::MutableOperandRange(OpOperand &opOperand)
+    : MutableOperandRange(opOperand.getOwner(),
+                          /*start=*/opOperand.getOperandNumber(),
+                          /*length=*/1) {}
+
 /// Slice this range into a sub range, with the additional operand segment.
 MutableOperandRange
 MutableOperandRange::slice(unsigned subStart, unsigned subLen,
@@ -517,6 +523,10 @@ void MutableOperandRange::updateLength(unsigned newLength) {
   }
 }
 
+OpOperand &MutableOperandRange::operator[](unsigned index) const {
+  return owner->getOpOperand(start + index);
+}
+
 //===----------------------------------------------------------------------===//
 // MutableOperandRangeRange
 
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 84f23584e9f30e3..9aab89ed7553600 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -277,7 +277,8 @@ class EdgeMultiplexer {
       if (index >= result->second &&
           index < result->second + edge.getSuccessor()->getNumArguments()) {
         // Original block arguments to the entry block.
-        newSuccOperands[index] = successorOperands[index - result->second];
+        newSuccOperands[index] =
+            successorOperands[index - result->second].get();
         continue;
       }
 
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index 00c251936655d71..e3d86b4a44d0001 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions(
 
 OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   assert(point == getBody());
-  return getInitMutable();
+  return MutableOperandRange(getInitMutable());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index ad4f53c5af3cff4..df1d13d3bf5580d 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2071,29 +2071,36 @@ void OpEmitter::genNamedOperandSetters() {
       continue;
     std::string name = op.getGetterName(operand.name);
 
-    auto *m = opClass.addMethod(operand.isVariadicOfVariadic()
-                                    ? "::mlir::MutableOperandRangeRange"
-                                    : "::mlir::MutableOperandRange",
-                                name + "Mutable");
+    StringRef returnType;
+    if (operand.isVariadicOfVariadic()) {
+      returnType = "::mlir::MutableOperandRangeRange";
+    } else if (operand.isVariableLength()) {
+      returnType = "::mlir::MutableOperandRange";
+    } else {
+      returnType = "::mlir::OpOperand &";
+    }
+    auto *m = opClass.addMethod(returnType, name + "Mutable");
     ERROR_IF_PRUNED(m, name, op);
     auto &body = m->body();
-    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
-         << "  auto mutableRange = "
-            "::mlir::MutableOperandRange(getOperation(), "
-            "range.first, range.second";
-    if (attrSizedOperands) {
-      if (emitHelper.hasProperties())
-        body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
-                        "{{getOperandSegmentSizesAttrName(), "
-                        "::mlir::DenseI32ArrayAttr::get(getContext(), "
-                        "getProperties().operandSegmentSizes)})",
-                        i);
-      else
-        body << formatv(
-            ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
-            emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n";
+    if (operand.isVariadicOfVariadic() || operand.isVariableLength()) {
+      body << "  auto mutableRange = "
+              "::mlir::MutableOperandRange(getOperation(), "
+              "range.first, range.second";
+      if (attrSizedOperands) {
+        if (emitHelper.hasProperties())
+          body << formatv(", ::mlir::MutableOperandRange::OperandSegment({0}u, "
+                          "{{getOperandSegmentSizesAttrName(), "
+                          "::mlir::DenseI32ArrayAttr::get(getContext(), "
+                          "getProperties().operandSegmentSizes)})",
+                          i);
+        else
+          body << formatv(
+              ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i,
+              emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true));
+      }
+      body << ");\n";
     }
-    body << ");\n";
 
     // If this operand is a nested variadic, we split the range into a
     // MutableOperandRangeRange that provides a range over all of the
@@ -2104,9 +2111,13 @@ void OpEmitter::genNamedOperandSetters() {
            << op.getGetterName(
                   operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
            << "AttrName()));\n";
-    } else {
-      // Otherwise, we use the full range directly.
+    } else if (operand.isVariableLength()) {
+      // Otherwise, if the operand has variable length, we use the full range
+      // directly.
       body << "  return mutableRange;\n";
+    } else {
+      // In case of a single operand, return a single OpOperand.
+      body << "  return getOperation()->getOpOperand(range.first);\n";
     }
   }
 }

@matthias-springer matthias-springer changed the title [mlir][TblGen] get...Mutable returns OpOperand & for single operands [mlir][TblGen] get...Mutable returns OpOperand & for single operands Sep 15, 2023
@matthias-springer matthias-springer force-pushed the tblgen_get_operand_mutable branch from 24f3431 to 04dc3b3 Compare September 15, 2023 15:23
@matthias-springer matthias-springer changed the title [mlir][TblGen] get...Mutable returns OpOperand & for single operands [mlir][ODS] get...Mutable returns OpOperand & for single operands Sep 17, 2023
@matthias-springer matthias-springer force-pushed the tblgen_get_operand_mutable branch from 04dc3b3 to ee804c9 Compare September 18, 2023 08:01
@matthias-springer matthias-springer changed the title [mlir][ODS] get...Mutable returns OpOperand & for single operands [mlir][ODS] Change get...Mutable to return OpOperand & for single operands Sep 18, 2023
The TableGen code generator now generates C++ code that returns a single `OpOperand &` for `get...Mutable` of operands that are not variadic and not optional. `OpOperand::set`/`assign`  can be used to set a value (same as `MutableOperandRange::assign`). It is safer than `MutableOperandRange` only single values (and no longer `ValueRange`) can be assigned.

E.g.:
```
// Before: Assign multiple values to non-variadic operand (forbidden, but
//         compiles).
// After: Compilation error.
extractSliceOp.getSourceMutable().assign({v1, v2});
```

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
@matthias-springer matthias-springer force-pushed the tblgen_get_operand_mutable branch from ee804c9 to 7e5e137 Compare October 3, 2023 14:25
Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good in general, looks like it'll avoid a few more footguns. Thanks!

@matthias-springer matthias-springer merged commit 8823e96 into llvm:main Oct 4, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:linalg mlir:scf mlir:tensor mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants