Skip to content

Commit f4125e0

Browse files
authored
[mlir python] Change PyOpView constructor to construct operations. (#123777)
Previously ODS-generated Python operations had code like this: ``` super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) ``` we change it to: ``` super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip) ``` This: a) avoids an extra call dispatch (to `build_generic`), and b) passes the class attributes directly to the constructor. Benchmarks show that it is faster to pass these as arguments rather than having the C++ code look up attributes on the class. This PR improves the timing of the following benchmark on my workstation from 5.3s to 4.5s: ``` def main(_): with ir.Context(), ir.Location.unknown(): typ = ir.IntegerType.get_signless(32) m = ir.Module.create() with ir.InsertionPoint(m.body): start = time.time() for i in range(1000000): arith.ConstantOp(typ, i) end = time.time() print(f"time: {end - start}") ``` Since this change adds an additional overload to the constructor and does not alter any existing behaviors, it should be backwards compatible.
1 parent 43177b5 commit f4125e0

File tree

4 files changed

+124
-53
lines changed

4 files changed

+124
-53
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
211211
return mlirStringRefCreate(s.data(), s.size());
212212
}
213213

214+
static MlirStringRef toMlirStringRef(std::string_view s) {
215+
return mlirStringRefCreate(s.data(), s.size());
216+
}
217+
214218
static MlirStringRef toMlirStringRef(const nb::bytes &s) {
215219
return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
216220
}
@@ -1460,7 +1464,7 @@ static void maybeInsertOperation(PyOperationRef &op,
14601464
}
14611465
}
14621466

1463-
nb::object PyOperation::create(const std::string &name,
1467+
nb::object PyOperation::create(std::string_view name,
14641468
std::optional<std::vector<PyType *>> results,
14651469
std::optional<std::vector<PyValue *>> operands,
14661470
std::optional<nb::dict> attributes,
@@ -1506,7 +1510,7 @@ nb::object PyOperation::create(const std::string &name,
15061510
} catch (nb::cast_error &err) {
15071511
std::string msg = "Invalid attribute key (not a string) when "
15081512
"attempting to create the operation \"" +
1509-
name + "\" (" + err.what() + ")";
1513+
std::string(name) + "\" (" + err.what() + ")";
15101514
throw nb::type_error(msg.c_str());
15111515
}
15121516
try {
@@ -1516,13 +1520,14 @@ nb::object PyOperation::create(const std::string &name,
15161520
} catch (nb::cast_error &err) {
15171521
std::string msg = "Invalid attribute value for the key \"" + key +
15181522
"\" when attempting to create the operation \"" +
1519-
name + "\" (" + err.what() + ")";
1523+
std::string(name) + "\" (" + err.what() + ")";
15201524
throw nb::type_error(msg.c_str());
15211525
} catch (std::runtime_error &) {
15221526
// This exception seems thrown when the value is "None".
15231527
std::string msg =
15241528
"Found an invalid (`None`?) attribute value for the key \"" + key +
1525-
"\" when attempting to create the operation \"" + name + "\"";
1529+
"\" when attempting to create the operation \"" +
1530+
std::string(name) + "\"";
15261531
throw std::runtime_error(msg);
15271532
}
15281533
}
@@ -1714,27 +1719,25 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
17141719
}
17151720

17161721
nb::object PyOpView::buildGeneric(
1717-
const nb::object &cls, std::optional<nb::list> resultTypeList,
1718-
nb::list operandList, std::optional<nb::dict> attributes,
1722+
std::string_view name, std::tuple<int, bool> opRegionSpec,
1723+
nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
1724+
std::optional<nb::list> resultTypeList, nb::list operandList,
1725+
std::optional<nb::dict> attributes,
17191726
std::optional<std::vector<PyBlock *>> successors,
17201727
std::optional<int> regions, DefaultingPyLocation location,
17211728
const nb::object &maybeIp) {
17221729
PyMlirContextRef context = location->getContext();
1730+
17231731
// Class level operation construction metadata.
1724-
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
17251732
// Operand and result segment specs are either none, which does no
17261733
// variadic unpacking, or a list of ints with segment sizes, where each
17271734
// element is either a positive number (typically 1 for a scalar) or -1 to
17281735
// indicate that it is derived from the length of the same-indexed operand
17291736
// or result (implying that it is a list at that position).
1730-
nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1731-
nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1732-
17331737
std::vector<int32_t> operandSegmentLengths;
17341738
std::vector<int32_t> resultSegmentLengths;
17351739

17361740
// Validate/determine region count.
1737-
auto opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
17381741
int opMinRegionCount = std::get<0>(opRegionSpec);
17391742
bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
17401743
if (!regions) {
@@ -3236,6 +3239,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32363239
auto opViewClass =
32373240
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
32383241
.def(nb::init<nb::object>(), nb::arg("operation"))
3242+
.def(
3243+
"__init__",
3244+
[](PyOpView *self, std::string_view name,
3245+
std::tuple<int, bool> opRegionSpec,
3246+
nb::object operandSegmentSpecObj,
3247+
nb::object resultSegmentSpecObj,
3248+
std::optional<nb::list> resultTypeList, nb::list operandList,
3249+
std::optional<nb::dict> attributes,
3250+
std::optional<std::vector<PyBlock *>> successors,
3251+
std::optional<int> regions, DefaultingPyLocation location,
3252+
const nb::object &maybeIp) {
3253+
new (self) PyOpView(PyOpView::buildGeneric(
3254+
name, opRegionSpec, operandSegmentSpecObj,
3255+
resultSegmentSpecObj, resultTypeList, operandList,
3256+
attributes, successors, regions, location, maybeIp));
3257+
},
3258+
nb::arg("name"), nb::arg("opRegionSpec"),
3259+
nb::arg("operandSegmentSpecObj").none() = nb::none(),
3260+
nb::arg("resultSegmentSpecObj").none() = nb::none(),
3261+
nb::arg("results").none() = nb::none(),
3262+
nb::arg("operands").none() = nb::none(),
3263+
nb::arg("attributes").none() = nb::none(),
3264+
nb::arg("successors").none() = nb::none(),
3265+
nb::arg("regions").none() = nb::none(),
3266+
nb::arg("loc").none() = nb::none(),
3267+
nb::arg("ip").none() = nb::none())
3268+
32393269
.def_prop_ro("operation", &PyOpView::getOperationObject)
32403270
.def_prop_ro("opview", [](nb::object self) { return self; })
32413271
.def(
@@ -3250,9 +3280,26 @@ void mlir::python::populateIRCore(nb::module_ &m) {
32503280
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
32513281
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
32523282
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
3283+
// It is faster to pass the operation_name, ods_regions, and
3284+
// ods_operand_segments/ods_result_segments as arguments to the constructor,
3285+
// rather than to access them as attributes.
32533286
opViewClass.attr("build_generic") = classmethod(
3254-
&PyOpView::buildGeneric, nb::arg("cls"),
3255-
nb::arg("results").none() = nb::none(),
3287+
[](nb::handle cls, std::optional<nb::list> resultTypeList,
3288+
nb::list operandList, std::optional<nb::dict> attributes,
3289+
std::optional<std::vector<PyBlock *>> successors,
3290+
std::optional<int> regions, DefaultingPyLocation location,
3291+
const nb::object &maybeIp) {
3292+
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
3293+
std::tuple<int, bool> opRegionSpec =
3294+
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
3295+
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
3296+
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
3297+
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
3298+
resultSegmentSpec, resultTypeList,
3299+
operandList, attributes, successors,
3300+
regions, location, maybeIp);
3301+
},
3302+
nb::arg("cls"), nb::arg("results").none() = nb::none(),
32563303
nb::arg("operands").none() = nb::none(),
32573304
nb::arg("attributes").none() = nb::none(),
32583305
nb::arg("successors").none() = nb::none(),

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
685685

686686
/// Creates an operation. See corresponding python docstring.
687687
static nanobind::object
688-
create(const std::string &name, std::optional<std::vector<PyType *>> results,
688+
create(std::string_view name, std::optional<std::vector<PyType *>> results,
689689
std::optional<std::vector<PyValue *>> operands,
690690
std::optional<nanobind::dict> attributes,
691691
std::optional<std::vector<PyBlock *>> successors, int regions,
@@ -739,12 +739,16 @@ class PyOpView : public PyOperationBase {
739739

740740
nanobind::object getOperationObject() { return operationObject; }
741741

742-
static nanobind::object buildGeneric(
743-
const nanobind::object &cls, std::optional<nanobind::list> resultTypeList,
744-
nanobind::list operandList, std::optional<nanobind::dict> attributes,
745-
std::optional<std::vector<PyBlock *>> successors,
746-
std::optional<int> regions, DefaultingPyLocation location,
747-
const nanobind::object &maybeIp);
742+
static nanobind::object
743+
buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec,
744+
nanobind::object operandSegmentSpecObj,
745+
nanobind::object resultSegmentSpecObj,
746+
std::optional<nanobind::list> resultTypeList,
747+
nanobind::list operandList,
748+
std::optional<nanobind::dict> attributes,
749+
std::optional<std::vector<PyBlock *>> successors,
750+
std::optional<int> regions, DefaultingPyLocation location,
751+
const nanobind::object &maybeIp);
748752

749753
/// Construct an instance of a class deriving from OpView, bypassing its
750754
/// `__init__` method. The derived class will typically define a constructor

0 commit comments

Comments
 (0)