Skip to content

[mlir python] Change PyOpView constructor to construct operations. #123777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 60 additions & 13 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
return mlirStringRefCreate(s.data(), s.size());
}

static MlirStringRef toMlirStringRef(std::string_view s) {
return mlirStringRefCreate(s.data(), s.size());
}

static MlirStringRef toMlirStringRef(const nb::bytes &s) {
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
}
Expand Down Expand Up @@ -1460,7 +1464,7 @@ static void maybeInsertOperation(PyOperationRef &op,
}
}

nb::object PyOperation::create(const std::string &name,
nb::object PyOperation::create(std::string_view name,
std::optional<std::vector<PyType *>> results,
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
Expand Down Expand Up @@ -1506,7 +1510,7 @@ nb::object PyOperation::create(const std::string &name,
} catch (nb::cast_error &err) {
std::string msg = "Invalid attribute key (not a string) when "
"attempting to create the operation \"" +
name + "\" (" + err.what() + ")";
std::string(name) + "\" (" + err.what() + ")";
throw nb::type_error(msg.c_str());
}
try {
Expand All @@ -1516,13 +1520,14 @@ nb::object PyOperation::create(const std::string &name,
} catch (nb::cast_error &err) {
std::string msg = "Invalid attribute value for the key \"" + key +
"\" when attempting to create the operation \"" +
name + "\" (" + err.what() + ")";
std::string(name) + "\" (" + err.what() + ")";
throw nb::type_error(msg.c_str());
} catch (std::runtime_error &) {
// This exception seems thrown when the value is "None".
std::string msg =
"Found an invalid (`None`?) attribute value for the key \"" + key +
"\" when attempting to create the operation \"" + name + "\"";
"\" when attempting to create the operation \"" +
std::string(name) + "\"";
throw std::runtime_error(msg);
}
}
Expand Down Expand Up @@ -1714,27 +1719,25 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
}

nb::object PyOpView::buildGeneric(
const nb::object &cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::string_view name, std::tuple<int, bool> opRegionSpec,
nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const nb::object &maybeIp) {
PyMlirContextRef context = location->getContext();

// Class level operation construction metadata.
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
// Operand and result segment specs are either none, which does no
// variadic unpacking, or a list of ints with segment sizes, where each
// element is either a positive number (typically 1 for a scalar) or -1 to
// indicate that it is derived from the length of the same-indexed operand
// or result (implying that it is a list at that position).
nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");

std::vector<int32_t> operandSegmentLengths;
std::vector<int32_t> resultSegmentLengths;

// Validate/determine region count.
auto opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
int opMinRegionCount = std::get<0>(opRegionSpec);
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
if (!regions) {
Expand Down Expand Up @@ -3236,6 +3239,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto opViewClass =
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
.def(nb::init<nb::object>(), nb::arg("operation"))
.def(
"__init__",
[](PyOpView *self, std::string_view name,
std::tuple<int, bool> opRegionSpec,
nb::object operandSegmentSpecObj,
nb::object resultSegmentSpecObj,
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const nb::object &maybeIp) {
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
attributes, successors, regions, location, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
nb::arg("resultSegmentSpecObj").none() = nb::none(),
nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
nb::arg("attributes").none() = nb::none(),
nb::arg("successors").none() = nb::none(),
nb::arg("regions").none() = nb::none(),
nb::arg("loc").none() = nb::none(),
nb::arg("ip").none() = nb::none())

.def_prop_ro("operation", &PyOpView::getOperationObject)
.def_prop_ro("opview", [](nb::object self) { return self; })
.def(
Expand All @@ -3250,9 +3280,26 @@ void mlir::python::populateIRCore(nb::module_ &m) {
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
// It is faster to pass the operation_name, ods_regions, and
// ods_operand_segments/ods_result_segments as arguments to the constructor,
// rather than to access them as attributes.
opViewClass.attr("build_generic") = classmethod(
&PyOpView::buildGeneric, nb::arg("cls"),
nb::arg("results").none() = nb::none(),
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
regions, location, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
nb::arg("attributes").none() = nb::none(),
nb::arg("successors").none() = nb::none(),
Expand Down
18 changes: 11 additions & 7 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {

/// Creates an operation. See corresponding python docstring.
static nanobind::object
create(const std::string &name, std::optional<std::vector<PyType *>> results,
create(std::string_view name, std::optional<std::vector<PyType *>> results,
std::optional<std::vector<PyValue *>> operands,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
Expand Down Expand Up @@ -739,12 +739,16 @@ class PyOpView : public PyOperationBase {

nanobind::object getOperationObject() { return operationObject; }

static nanobind::object buildGeneric(
const nanobind::object &cls, std::optional<nanobind::list> resultTypeList,
nanobind::list operandList, std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const nanobind::object &maybeIp);
static nanobind::object
buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec,
nanobind::object operandSegmentSpecObj,
nanobind::object resultSegmentSpecObj,
std::optional<nanobind::list> resultTypeList,
nanobind::list operandList,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
std::optional<int> regions, DefaultingPyLocation location,
const nanobind::object &maybeIp);

/// Construct an instance of a class deriving from OpView, bypassing its
/// `__init__` method. The derived class will typically define a constructor
Expand Down
Loading
Loading