Skip to content

Commit 2a5d497

Browse files
[mlir][python] Add __{bool,float,int,str}__ to bindings of attributes.
This allows to use Python's `bool(.)`, `float(.)`, `int(.)`, and `str(.)` to convert pybound attributes to the corresponding native Python types. In particular, pybind11 uses these functions to automatically cast objects to the corresponding primitive types wherever they are required by pybound functions, e.g., arguments are converted to Python's `int` if the C++ signature requires a C++ `int`. With this patch, pybound attributes can by used wherever the corresponding native types are expected. New tests show-case this behavior in the constructors of `Dense*ArrayAttr`. Note that this changes the output of Python's `str` on `StringAttr` from `"hello"` to `hello`. Arguably, this is still in line with `str`s goal of producing a readable interpretation of the value, even if it is now not unambiously a string anymore (`print(ir.Attribute.parse('"42"'))` now outputs `42`). However, this is consistent with instances of Python's `str` (`print("42")` outputs `42`), and `repr` still provides an unambigous representation if one is required. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158974
1 parent 72079d9 commit 2a5d497

File tree

2 files changed

+123
-38
lines changed

2 files changed

+123
-38
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -389,12 +389,10 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
389389
},
390390
py::arg("value"), py::arg("context") = py::none(),
391391
"Gets an uniqued float point attribute associated to a f64 type");
392-
c.def_property_readonly(
393-
"value",
394-
[](PyFloatAttribute &self) {
395-
return mlirFloatAttrGetValueDouble(self);
396-
},
397-
"Returns the value of the float point attribute");
392+
c.def_property_readonly("value", mlirFloatAttrGetValueDouble,
393+
"Returns the value of the float attribute");
394+
c.def("__float__", mlirFloatAttrGetValueDouble,
395+
"Converts the value of the float attribute to a Python float");
398396
}
399397
};
400398

@@ -414,22 +412,25 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
414412
},
415413
py::arg("type"), py::arg("value"),
416414
"Gets an uniqued integer attribute associated to a type");
417-
c.def_property_readonly(
418-
"value",
419-
[](PyIntegerAttribute &self) -> py::int_ {
420-
MlirType type = mlirAttributeGetType(self);
421-
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
422-
return mlirIntegerAttrGetValueInt(self);
423-
if (mlirIntegerTypeIsSigned(type))
424-
return mlirIntegerAttrGetValueSInt(self);
425-
return mlirIntegerAttrGetValueUInt(self);
426-
},
427-
"Returns the value of the integer attribute");
415+
c.def_property_readonly("value", toPyInt,
416+
"Returns the value of the integer attribute");
417+
c.def("__int__", toPyInt,
418+
"Converts the value of the integer attribute to a Python int");
428419
c.def_property_readonly_static("static_typeid",
429420
[](py::object & /*class*/) -> MlirTypeID {
430421
return mlirIntegerAttrGetTypeID();
431422
});
432423
}
424+
425+
private:
426+
static py::int_ toPyInt(PyIntegerAttribute &self) {
427+
MlirType type = mlirAttributeGetType(self);
428+
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
429+
return mlirIntegerAttrGetValueInt(self);
430+
if (mlirIntegerTypeIsSigned(type))
431+
return mlirIntegerAttrGetValueSInt(self);
432+
return mlirIntegerAttrGetValueUInt(self);
433+
}
433434
};
434435

435436
/// Bool Attribute subclass - BoolAttr.
@@ -448,10 +449,10 @@ class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
448449
},
449450
py::arg("value"), py::arg("context") = py::none(),
450451
"Gets an uniqued bool attribute");
451-
c.def_property_readonly(
452-
"value",
453-
[](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
454-
"Returns the value of the bool attribute");
452+
c.def_property_readonly("value", mlirBoolAttrGetValue,
453+
"Returns the value of the bool attribute");
454+
c.def("__bool__", mlirBoolAttrGetValue,
455+
"Converts the value of the bool attribute to a Python bool");
455456
}
456457
};
457458

@@ -595,20 +596,23 @@ class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
595596
},
596597
py::arg("type"), py::arg("value"),
597598
"Gets a uniqued string attribute associated to a type");
598-
c.def_property_readonly(
599-
"value",
600-
[](PyStringAttribute &self) {
601-
MlirStringRef stringRef = mlirStringAttrGetValue(self);
602-
return py::str(stringRef.data, stringRef.length);
603-
},
604-
"Returns the value of the string attribute");
599+
c.def_property_readonly("value", toPyStr,
600+
"Returns the value of the string attribute");
605601
c.def_property_readonly(
606602
"value_bytes",
607603
[](PyStringAttribute &self) {
608604
MlirStringRef stringRef = mlirStringAttrGetValue(self);
609605
return py::bytes(stringRef.data, stringRef.length);
610606
},
611607
"Returns the value of the string attribute as `bytes`");
608+
c.def("__str__", toPyStr,
609+
"Converts the value of the string attribute to a Python str");
610+
}
611+
612+
private:
613+
static py::str toPyStr(PyStringAttribute &self) {
614+
MlirStringRef stringRef = mlirStringAttrGetValue(self);
615+
return py::str(stringRef.data, stringRef.length);
612616
}
613617
};
614618

mlir/test/python/ir/attributes.py

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def testParsePrint():
2121
assert t.context is ctx
2222
ctx = None
2323
gc.collect()
24-
# CHECK: "hello"
24+
# CHECK: hello
2525
print(str(t))
2626
# CHECK: StringAttr("hello")
2727
print(repr(t))
@@ -169,6 +169,8 @@ def testFloatAttr():
169169
fattr = FloatAttr(Attribute.parse("42.0 : f32"))
170170
# CHECK: fattr value: 42.0
171171
print("fattr value:", fattr.value)
172+
# CHECK: fattr float: 42.0 <class 'float'>
173+
print("fattr float:", float(fattr), type(float(fattr)))
172174

173175
# Test factory methods.
174176
# CHECK: default_get: 4.200000e+01 : f32
@@ -196,15 +198,23 @@ def testIntegerAttr():
196198
print("i_attr value:", i_attr.value)
197199
# CHECK: i_attr type: i64
198200
print("i_attr type:", i_attr.type)
201+
# CHECK: i_attr int: 42 <class 'int'>
202+
print("i_attr int:", int(i_attr), type(int(i_attr)))
199203
si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
200204
# CHECK: si_attr value: -1
201205
print("si_attr value:", si_attr.value)
202206
ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
207+
# CHECK: i_attr int: -1 <class 'int'>
208+
print("si_attr int:", int(si_attr), type(int(si_attr)))
203209
# CHECK: ui_attr value: 255
204210
print("ui_attr value:", ui_attr.value)
211+
# CHECK: i_attr int: 255 <class 'int'>
212+
print("ui_attr int:", int(ui_attr), type(int(ui_attr)))
205213
idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
206214
# CHECK: idx_attr value: -1
207215
print("idx_attr value:", idx_attr.value)
216+
# CHECK: idx_attr int: -1 <class 'int'>
217+
print("idx_attr int:", int(idx_attr), type(int(idx_attr)))
208218

209219
# Test factory methods.
210220
# CHECK: default_get: 42 : i32
@@ -218,6 +228,8 @@ def testBoolAttr():
218228
battr = BoolAttr(Attribute.parse("true"))
219229
# CHECK: iattr value: True
220230
print("iattr value:", battr.value)
231+
# CHECK: iattr bool: True <class 'bool'>
232+
print("iattr bool:", bool(battr), type(bool(battr)))
221233

222234
# Test factory methods.
223235
# CHECK: default_get: true
@@ -278,14 +290,25 @@ def testStringAttr():
278290
sattr = StringAttr(Attribute.parse('"stringattr"'))
279291
# CHECK: sattr value: stringattr
280292
print("sattr value:", sattr.value)
281-
# CHECK: sattr value: b'stringattr'
282-
print("sattr value:", sattr.value_bytes)
293+
# CHECK: sattr value_bytes: b'stringattr'
294+
print("sattr value_bytes:", sattr.value_bytes)
295+
# CHECK: sattr str: stringattr
296+
print("sattr str:", str(sattr))
297+
298+
typed_sattr = StringAttr(Attribute.parse('"stringattr" : i32'))
299+
# CHECK: typed_sattr value: stringattr
300+
print("typed_sattr value:", typed_sattr.value)
301+
# CHECK: typed_sattr str: stringattr
302+
print("typed_sattr str:", str(typed_sattr))
283303

284304
# Test factory methods.
285-
# CHECK: default_get: "foobar"
286-
print("default_get:", StringAttr.get("foobar"))
287-
# CHECK: typed_get: "12345" : i32
288-
print("typed_get:", StringAttr.get_typed(IntegerType.get_signless(32), "12345"))
305+
# CHECK: default_get: StringAttr("foobar")
306+
print("default_get:", repr(StringAttr.get("foobar")))
307+
# CHECK: typed_get: StringAttr("12345" : i32)
308+
print(
309+
"typed_get:",
310+
repr(StringAttr.get_typed(IntegerType.get_signless(32), "12345")),
311+
)
289312

290313

291314
# CHECK-LABEL: TEST: testNamedAttr
@@ -294,8 +317,8 @@ def testNamedAttr():
294317
with Context():
295318
a = Attribute.parse('"stringattr"')
296319
named = a.get_named("foobar") # Note: under the small object threshold
297-
# CHECK: attr: "stringattr"
298-
print("attr:", named.attr)
320+
# CHECK: attr: StringAttr("stringattr")
321+
print("attr:", repr(named.attr))
299322
# CHECK: name: foobar
300323
print("name:", named.name)
301324
# CHECK: named: NamedAttribute(foobar="stringattr")
@@ -367,6 +390,65 @@ def __bool__(self):
367390
print("myboolarray:", DenseBoolArrayAttr.get([MyBool()]))
368391

369392

393+
# CHECK-LABEL: TEST: testDenseArrayAttrConstruction
394+
@run
395+
def testDenseArrayAttrConstruction():
396+
with Context(), Location.unknown():
397+
398+
def create_and_print(cls, x):
399+
try:
400+
darr = cls.get(x)
401+
print(f"input: {x} ({type(x)}), result: {darr}")
402+
except Exception as ex:
403+
print(f"input: {x} ({type(x)}), error: {ex}")
404+
405+
# CHECK: input: [4, 2] (<class 'list'>),
406+
# CHECK-SAME: result: array<i8: 4, 2>
407+
create_and_print(DenseI8ArrayAttr, [4, 2])
408+
409+
# CHECK: input: [4, 2.0] (<class 'list'>),
410+
# CHECK-SAME: error: get(): incompatible function arguments
411+
create_and_print(DenseI8ArrayAttr, [4, 2.0])
412+
413+
# CHECK: input: [40000, 2] (<class 'list'>),
414+
# CHECK-SAME: error: get(): incompatible function arguments
415+
create_and_print(DenseI8ArrayAttr, [40000, 2])
416+
417+
# CHECK: input: range(0, 4) (<class 'range'>),
418+
# CHECK-SAME: result: array<i8: 0, 1, 2, 3>
419+
create_and_print(DenseI8ArrayAttr, range(4))
420+
421+
# CHECK: input: [IntegerAttr(4 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
422+
# CHECK-SAME: result: array<i8: 4, 2>
423+
create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4, 2]])
424+
425+
# CHECK: input: [IntegerAttr(4000 : i64), IntegerAttr(2 : i64)] (<class 'list'>),
426+
# CHECK-SAME: error: get(): incompatible function arguments
427+
create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4000, 2]])
428+
429+
# CHECK: input: [IntegerAttr(4 : i64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>),
430+
# CHECK-SAME: error: get(): incompatible function arguments
431+
create_and_print(DenseI8ArrayAttr, [Attribute.parse(f"{x}") for x in [4, 2.0]])
432+
433+
# CHECK: input: [IntegerAttr(4 : i8), IntegerAttr(2 : ui16)] (<class 'list'>),
434+
# CHECK-SAME: result: array<i8: 4, 2>
435+
create_and_print(
436+
DenseI8ArrayAttr, [Attribute.parse(s) for s in ["4 : i8", "2 : ui16"]]
437+
)
438+
439+
# CHECK: input: [FloatAttr(4.000000e+00 : f64), FloatAttr(2.000000e+00 : f64)] (<class 'list'>)
440+
# CHECK-SAME: result: array<f32: 4.000000e+00, 2.000000e+00>
441+
create_and_print(
442+
DenseF32ArrayAttr, [Attribute.parse(f"{x}") for x in [4.0, 2.0]]
443+
)
444+
445+
# CHECK: [BoolAttr(true), BoolAttr(false)] (<class 'list'>),
446+
# CHECK-SAME: result: array<i1: true, false>
447+
create_and_print(
448+
DenseBoolArrayAttr, [Attribute.parse(f"{x}") for x in ["true", "false"]]
449+
)
450+
451+
370452
# CHECK-LABEL: TEST: testDenseIntAttrGetItem
371453
@run
372454
def testDenseIntAttrGetItem():
@@ -620,7 +702,6 @@ def print_container_item(attr_asm):
620702
@run
621703
def testConcreteAttributesRoundTrip():
622704
with Context(), Location.unknown():
623-
624705
# CHECK: FloatAttr(4.200000e+01 : f32)
625706
print(repr(Attribute.parse("42.0 : f32")))
626707

0 commit comments

Comments
 (0)