Skip to content

[mlir,python] Expose replaceAllUsesExcept to Python bindings #115850

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir-c/IR.h
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,15 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
MlirValue with);

/// Replace all uses of 'of' value with 'with' value, updating anything in the
/// IR that uses 'of' to use 'with' instead, except if the user is listed in
/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation
/// pointers with a length of 'numExceptions'.
MLIR_CAPI_EXPORTED void
mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
intptr_t numExceptions,
MlirOperation *exceptions);

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] =
the IR that uses 'self' to use the other value instead.
)";

static const char kValueReplaceAllUsesExceptDocstring[] =
R"("Replace all uses of this value with the 'with' value, except for those
in 'exceptions'. 'exceptions' can be either a single operation or a list of
operations.
)";

//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
Expand Down Expand Up @@ -3718,6 +3724,29 @@ void mlir::python::populateIRCore(py::module &m) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring)
.def(
"replace_all_uses_except",
[](MlirValue self, MlirValue with, PyOperation &exception) {
MlirOperation exceptedUser = exception.get();
mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
},
py::arg("with"), py::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(
"replace_all_uses_except",
[](MlirValue self, MlirValue with, py::list exceptions) {
// Convert Python list to a SmallVector of MlirOperations
llvm::SmallVector<MlirOperation> exceptionOps;
for (py::handle exception : exceptions) {
exceptionOps.push_back(exception.cast<PyOperation &>().get());
}

mlirValueReplaceAllUsesExcept(
self, with, static_cast<intptr_t>(exceptionOps.size()),
exceptionOps.data());
},
py::arg("with"), py::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
PyBlockArgument::bind(m);
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/CAPI/IR/IR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ThreadPool.h"

#include <cstddef>
Expand Down Expand Up @@ -1009,6 +1010,20 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
}

void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
intptr_t numExceptions,
MlirOperation *exceptions) {
Value oldValueCpp = unwrap(oldValue);
Value newValueCpp = unwrap(newValue);

llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
for (intptr_t i = 0; i < numExceptions; ++i) {
exceptionSet.insert(unwrap(exceptions[i]));
}

oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
}

//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
Expand Down
71 changes: 71 additions & 0 deletions mlir/test/python/ir/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,77 @@ def testValueReplaceAllUsesWith():
print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
@run
def testValueReplaceAllUsesWithExcept():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
i32 = IntegerType.get_signless(32)
module = Module.create()
with InsertionPoint(module.body):
value = Operation.create("custom.op1", results=[i32]).results[0]
op1 = Operation.create("custom.op1", operands=[value])
op2 = Operation.create("custom.op2", operands=[value])
value2 = Operation.create("custom.op3", results=[i32]).results[0]
value.replace_all_uses_except(value2, op1)

assert len(list(value.uses)) == 1

# CHECK: Use owner: "custom.op2"
# CHECK: Use operand_number: 0
for use in value2.uses:
assert use.owner in [op2]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")

# CHECK: Use owner: "custom.op1"
# CHECK: Use operand_number: 0
for use in value.uses:
assert use.owner in [op1]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
@run
def testValueReplaceAllUsesWithMultipleExceptions():
ctx = Context()
ctx.allow_unregistered_dialects = True
with Location.unknown(ctx):
i32 = IntegerType.get_signless(32)
module = Module.create()
with InsertionPoint(module.body):
value = Operation.create("custom.op1", results=[i32]).results[0]
op1 = Operation.create("custom.op1", operands=[value])
op2 = Operation.create("custom.op2", operands=[value])
op3 = Operation.create("custom.op3", operands=[value])
value2 = Operation.create("custom.op4", results=[i32]).results[0]

# Replace all uses of `value` with `value2`, except for `op1` and `op2`.
value.replace_all_uses_except(value2, [op1, op2])

# After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
assert len(list(value.uses)) == 2
assert len(list(value2.uses)) == 1

# CHECK: Use owner: "custom.op3"
# CHECK: Use operand_number: 0
for use in value2.uses:
assert use.owner in [op3]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")

# CHECK: Use owner: "custom.op2"
# CHECK: Use operand_number: 0
# CHECK: Use owner: "custom.op1"
# CHECK: Use operand_number: 0
for use in value.uses:
assert use.owner in [op1, op2]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")


# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
Expand Down
Loading