Skip to content

Commit ca23c93

Browse files
[mlir][python] Create all missing attribute builders.
This patch adds attribute builders for all buildable attributes from the builtin dialect that did not previously have any. These builders can be used to construct attributes of a particular type identified by a string from a Python argument without knowing the details of how to pass that Python argument to the attribute constructor. This is used, for example, in the generated code of the Python bindings of ops. The list of "all" attributes was produced with: ( grep -h "ods_ir.AttrBuilder.get" $(find ../build/ -name "*_ops_gen.py") \ | cut -f2 -d"'" git grep -ho "^def [a-zA-Z0-9_]*" -- include/mlir/IR/CommonAttrConstraints.td \ | cut -f2 -d" " ) | sort -u Then, I only retained those that had an occurence in `mlir/include/mlir/IR`. In particular, this drops many dialect-specific attributes; registering those builders is something that those dialects should do. Finally, I removed those attrbiutes that had a match in `mlir/python/mlir/ir.py` already and implemented the remaining ones. The only ones that still miss a builder now are the following: * Represent more than one possible attribute type: - `Any.*Attr` (9x) - `IntNonNegative` - `IntPositive` - `IsNullAttr` - `ElementsAttr` * I am not sure what "constant attributes" are: - `ConstBoolAttrFalse` - `ConstBoolAttrTrue` - `ConstUnitAttr` * `Location` not exposed by Python bindings: - `LocationArrayAttr` - `LocationAttr` * `get` function not implemented in Python bindings: - `StringElementsAttr` This patch also fixes a compilation problem with `I64SmallVectorArrayAttr`. Reviewed By: makslevental, rkayaith Differential Revision: https://reviews.llvm.org/D159403
1 parent d26c78b commit ca23c93

File tree

4 files changed

+280
-31
lines changed

4 files changed

+280
-31
lines changed

mlir/include/mlir/IR/CommonAttrConstraints.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def I64SmallVectorArrayAttr :
612612
let convertFromStorage = [{
613613
llvm::to_vector<4>(
614614
llvm::map_range($_self.getAsRange<mlir::IntegerAttr>(),
615-
[](IntegerAttr attr) { return attr.getInt(); }));
615+
[](mlir::IntegerAttr attr) { return attr.getInt(); }));
616616
}];
617617
let constBuilderCall = "$_builder.getI64ArrayAttr($0)";
618618
}

mlir/python/mlir/ir.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,36 @@ def decorator_builder(func):
1616
return decorator_builder
1717

1818

19+
@register_attribute_builder("AffineMapAttr")
20+
def _affineMapAttr(x, context):
21+
return AffineMapAttr.get(x)
22+
23+
1924
@register_attribute_builder("BoolAttr")
2025
def _boolAttr(x, context):
2126
return BoolAttr.get(x, context=context)
2227

2328

29+
@register_attribute_builder("DictionaryAttr")
30+
def _dictAttr(x, context):
31+
return DictAttr.get(x, context=context)
32+
33+
2434
@register_attribute_builder("IndexAttr")
2535
def _indexAttr(x, context):
2636
return IntegerAttr.get(IndexType.get(context=context), x)
2737

2838

39+
@register_attribute_builder("I1Attr")
40+
def _i1Attr(x, context):
41+
return IntegerAttr.get(IntegerType.get_signless(1, context=context), x)
42+
43+
44+
@register_attribute_builder("I8Attr")
45+
def _i8Attr(x, context):
46+
return IntegerAttr.get(IntegerType.get_signless(8, context=context), x)
47+
48+
2949
@register_attribute_builder("I16Attr")
3050
def _i16Attr(x, context):
3151
return IntegerAttr.get(IntegerType.get_signless(16, context=context), x)
@@ -41,6 +61,16 @@ def _i64Attr(x, context):
4161
return IntegerAttr.get(IntegerType.get_signless(64, context=context), x)
4262

4363

64+
@register_attribute_builder("SI1Attr")
65+
def _si1Attr(x, context):
66+
return IntegerAttr.get(IntegerType.get_signed(1, context=context), x)
67+
68+
69+
@register_attribute_builder("SI8Attr")
70+
def _i8Attr(x, context):
71+
return IntegerAttr.get(IntegerType.get_signed(8, context=context), x)
72+
73+
4474
@register_attribute_builder("SI16Attr")
4575
def _si16Attr(x, context):
4676
return IntegerAttr.get(IntegerType.get_signed(16, context=context), x)
@@ -51,6 +81,36 @@ def _si32Attr(x, context):
5181
return IntegerAttr.get(IntegerType.get_signed(32, context=context), x)
5282

5383

84+
@register_attribute_builder("SI64Attr")
85+
def _si64Attr(x, context):
86+
return IntegerAttr.get(IntegerType.get_signed(64, context=context), x)
87+
88+
89+
@register_attribute_builder("UI1Attr")
90+
def _ui1Attr(x, context):
91+
return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x)
92+
93+
94+
@register_attribute_builder("UI8Attr")
95+
def _i8Attr(x, context):
96+
return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x)
97+
98+
99+
@register_attribute_builder("UI16Attr")
100+
def _ui16Attr(x, context):
101+
return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x)
102+
103+
104+
@register_attribute_builder("UI32Attr")
105+
def _ui32Attr(x, context):
106+
return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x)
107+
108+
109+
@register_attribute_builder("UI64Attr")
110+
def _ui64Attr(x, context):
111+
return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x)
112+
113+
54114
@register_attribute_builder("F32Attr")
55115
def _f32Attr(x, context):
56116
return FloatAttr.get_f32(x, context=context)
@@ -84,11 +144,39 @@ def _flatSymbolRefAttr(x, context):
84144
return FlatSymbolRefAttr.get(x, context=context)
85145

86146

147+
@register_attribute_builder("UnitAttr")
148+
def _unitAttr(x, context):
149+
if x:
150+
return UnitAttr.get(context=context)
151+
else:
152+
return None
153+
154+
87155
@register_attribute_builder("ArrayAttr")
88156
def _arrayAttr(x, context):
89157
return ArrayAttr.get(x, context=context)
90158

91159

160+
@register_attribute_builder("AffineMapArrayAttr")
161+
def _affineMapArrayAttr(x, context):
162+
return ArrayAttr.get([_affineMapAttr(v, context) for v in x])
163+
164+
165+
@register_attribute_builder("BoolArrayAttr")
166+
def _boolArrayAttr(x, context):
167+
return ArrayAttr.get([_boolAttr(v, context) for v in x])
168+
169+
170+
@register_attribute_builder("DictArrayAttr")
171+
def _dictArrayAttr(x, context):
172+
return ArrayAttr.get([_dictAttr(v, context) for v in x])
173+
174+
175+
@register_attribute_builder("FlatSymbolRefArrayAttr")
176+
def _flatSymbolRefArrayAttr(x, context):
177+
return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x])
178+
179+
92180
@register_attribute_builder("I32ArrayAttr")
93181
def _i32ArrayAttr(x, context):
94182
return ArrayAttr.get([_i32Attr(v, context) for v in x])
@@ -99,6 +187,16 @@ def _i64ArrayAttr(x, context):
99187
return ArrayAttr.get([_i64Attr(v, context) for v in x])
100188

101189

190+
@register_attribute_builder("I64SmallVectorArrayAttr")
191+
def _i64SmallVectorArrayAttr(x, context):
192+
return _i64ArrayAttr(x, context=context)
193+
194+
195+
@register_attribute_builder("IndexListArrayAttr")
196+
def _indexListArrayAttr(x, context):
197+
return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x])
198+
199+
102200
@register_attribute_builder("F32ArrayAttr")
103201
def _f32ArrayAttr(x, context):
104202
return ArrayAttr.get([_f32Attr(v, context) for v in x])
@@ -109,6 +207,41 @@ def _f64ArrayAttr(x, context):
109207
return ArrayAttr.get([_f64Attr(v, context) for v in x])
110208

111209

210+
@register_attribute_builder("StrArrayAttr")
211+
def _strArrayAttr(x, context):
212+
return ArrayAttr.get([_stringAttr(v, context) for v in x])
213+
214+
215+
@register_attribute_builder("SymbolRefArrayAttr")
216+
def _symbolRefArrayAttr(x, context):
217+
return ArrayAttr.get([_symbolRefAttr(v, context) for v in x])
218+
219+
220+
@register_attribute_builder("DenseF32ArrayAttr")
221+
def _denseF32ArrayAttr(x, context):
222+
return DenseF32ArrayAttr.get(x, context=context)
223+
224+
225+
@register_attribute_builder("DenseF64ArrayAttr")
226+
def _denseF64ArrayAttr(x, context):
227+
return DenseF64ArrayAttr.get(x, context=context)
228+
229+
230+
@register_attribute_builder("DenseI8ArrayAttr")
231+
def _denseI8ArrayAttr(x, context):
232+
return DenseI8ArrayAttr.get(x, context=context)
233+
234+
235+
@register_attribute_builder("DenseI16ArrayAttr")
236+
def _denseI16ArrayAttr(x, context):
237+
return DenseI16ArrayAttr.get(x, context=context)
238+
239+
240+
@register_attribute_builder("DenseI32ArrayAttr")
241+
def _denseI32ArrayAttr(x, context):
242+
return DenseI32ArrayAttr.get(x, context=context)
243+
244+
112245
@register_attribute_builder("DenseI64ArrayAttr")
113246
def _denseI64ArrayAttr(x, context):
114247
return DenseI64ArrayAttr.get(x, context=context)
@@ -132,6 +265,30 @@ def _typeArrayAttr(x, context):
132265
try:
133266
import numpy as np
134267

268+
@register_attribute_builder("F64ElementsAttr")
269+
def _f64ElementsAttr(x, context):
270+
return DenseElementsAttr.get(
271+
np.array(x, dtype=np.int64),
272+
type=F64Type.get(context=context),
273+
context=context,
274+
)
275+
276+
@register_attribute_builder("I32ElementsAttr")
277+
def _i32ElementsAttr(x, context):
278+
return DenseElementsAttr.get(
279+
np.array(x, dtype=np.int32),
280+
type=IntegerType.get_signed(32, context=context),
281+
context=context,
282+
)
283+
284+
@register_attribute_builder("I64ElementsAttr")
285+
def _i64ElementsAttr(x, context):
286+
return DenseElementsAttr.get(
287+
np.array(x, dtype=np.int64),
288+
type=IntegerType.get_signed(64, context=context),
289+
context=context,
290+
)
291+
135292
@register_attribute_builder("IndexElementsAttr")
136293
def _indexElementsAttr(x, context):
137294
return DenseElementsAttr.get(

mlir/test/python/dialects/python_test.py

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -140,23 +140,76 @@ def testAttributes():
140140
def attrBuilder():
141141
with Context() as ctx, Location.unknown():
142142
ctx.allow_unregistered_dialects = True
143+
# CHECK: python_test.attributes_op
143144
op = test.AttributesOp(
144-
x_bool=True,
145-
x_i16=1,
146-
x_i32=2,
147-
x_i64=3,
148-
x_si16=-1,
149-
x_si32=-2,
150-
x_f32=1.5,
151-
x_f64=2.5,
152-
x_str="x_str",
153-
x_i32_array=[1, 2, 3],
154-
x_i64_array=[4, 5, 6],
155-
x_f32_array=[1.5, -2.5, 3.5],
156-
x_f64_array=[4.5, 5.5, -6.5],
157-
x_i64_dense=[1, 2, 3, 4, 5, 6],
145+
# CHECK-DAG: x_affinemap = affine_map<() -> (2)>
146+
x_affinemap=AffineMap.get_constant(2),
147+
# CHECK-DAG: x_affinemaparr = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>]
148+
x_affinemaparr=[AffineMap.get_identity(3)],
149+
# CHECK-DAG: x_arr = [true, "x"]
150+
x_arr=[BoolAttr.get(True), StringAttr.get("x")],
151+
x_boolarr=[False, True], # CHECK-DAG: x_boolarr = [false, true]
152+
x_bool=True, # CHECK-DAG: x_bool = true
153+
x_dboolarr=[True, False], # CHECK-DAG: x_dboolarr = array<i1: true, false>
154+
x_df16arr=[21, 22], # CHECK-DAG: x_df16arr = array<i16: 21, 22>
155+
# CHECK-DAG: x_df32arr = array<f32: 2.300000e+01, 2.400000e+01>
156+
x_df32arr=[23, 24],
157+
# CHECK-DAG: x_df64arr = array<f64: 2.500000e+01, 2.600000e+01>
158+
x_df64arr=[25, 26],
159+
x_di32arr=[0, 1], # CHECK-DAG: x_di32arr = array<i32: 0, 1>
160+
# CHECK-DAG: x_di64arr = array<i64: 1, 2>
161+
x_di64arr=[1, 2],
162+
x_di8arr=[2, 3], # CHECK-DAG: x_di8arr = array<i8: 2, 3>
163+
# CHECK-DAG: x_dictarr = [{a = false}]
164+
x_dictarr=[{"a": BoolAttr.get(False)}],
165+
x_dict={"b": BoolAttr.get(True)}, # CHECK-DAG: x_dict = {b = true}
166+
x_f32=-2.25, # CHECK-DAG: x_f32 = -2.250000e+00 : f32
167+
# CHECK-DAG: x_f32arr = [2.000000e+00 : f32, 3.000000e+00 : f32]
168+
x_f32arr=[2.0, 3.0],
169+
x_f64=4.25, # CHECK-DAG: x_f64 = 4.250000e+00 : f64
170+
x_f64arr=[4.0, 8.0], # CHECK-DAG: x_f64arr = [4.000000e+00, 8.000000e+00]
171+
# CHECK-DAG: x_f64elems = dense<[3.952530e-323, 7.905050e-323]> : tensor<2xf64>
172+
x_f64elems=[8.0, 16.0],
173+
# CHECK-DAG: x_flatsymrefarr = [@symbol1, @symbol2]
174+
x_flatsymrefarr=["symbol1", "symbol2"],
175+
x_flatsymref="symbol3", # CHECK-DAG: x_flatsymref = @symbol3
176+
x_i1=0, # CHECK-DAG: x_i1 = false
177+
x_i16=42, # CHECK-DAG: x_i16 = 42 : i16
178+
x_i32=6, # CHECK-DAG: x_i32 = 6 : i32
179+
x_i32arr=[4, 5], # CHECK-DAG: x_i32arr = [4 : i32, 5 : i32]
180+
x_i32elems=[5, 6], # CHECK-DAG: x_i32elems = dense<[5, 6]> : tensor<2xsi32>
181+
x_i64=9, # CHECK-DAG: x_i64 = 9 : i64
182+
x_i64arr=[7, 8], # CHECK-DAG: x_i64arr = [7, 8]
183+
x_i64elems=[8, 9], # CHECK-DAG: x_i64elems = dense<[8, 9]> : tensor<2xsi64>
184+
x_i64svecarr=[10, 11], # CHECK-DAG: x_i64svecarr = [10, 11]
185+
x_i8=11, # CHECK-DAG: x_i8 = 11 : i8
186+
x_idx=10, # CHECK-DAG: x_idx = 10 : index
187+
# CHECK-DAG: x_idxelems = dense<[11, 12]> : tensor<2xindex>
188+
x_idxelems=[11, 12],
189+
# CHECK-DAG: x_idxlistarr = [{{\[}}13], [14, 15]]
190+
x_idxlistarr=[[13], [14, 15]],
191+
x_si1=-1, # CHECK-DAG: x_si1 = -1 : si1
192+
x_si16=-2, # CHECK-DAG: x_si16 = -2 : si16
193+
x_si32=-3, # CHECK-DAG: x_si32 = -3 : si32
194+
x_si64=-123, # CHECK-DAG: x_si64 = -123 : si64
195+
x_si8=-4, # CHECK-DAG: x_si8 = -4 : si8
196+
x_strarr=["hello", "world"], # CHECK-DAG: x_strarr = ["hello", "world"]
197+
x_str="hello world!", # CHECK-DAG: x_str = "hello world!"
198+
# CHECK-DAG: x_symrefarr = [@flatsym, @deep::@sym]
199+
x_symrefarr=["flatsym", ["deep", "sym"]],
200+
x_symref=["deep", "sym2"], # CHECK-DAG: x_symref = @deep::@sym2
201+
x_sym="symbol", # CHECK-DAG: x_sym = "symbol"
202+
x_typearr=[F32Type.get()], # CHECK-DAG: x_typearr = [f32]
203+
x_type=F64Type.get(), # CHECK-DAG: x_type = f64
204+
x_ui1=1, # CHECK-DAG: x_ui1 = 1 : ui1
205+
x_ui16=2, # CHECK-DAG: x_ui16 = 2 : ui16
206+
x_ui32=3, # CHECK-DAG: x_ui32 = 3 : ui32
207+
x_ui64=4, # CHECK-DAG: x_ui64 = 4 : ui64
208+
x_ui8=5, # CHECK-DAG: x_ui8 = 5 : ui8
209+
x_unit=True, # CHECK-DAG: x_unit
158210
)
159-
print(op)
211+
op.verify()
212+
op.print(use_local_scope=True)
160213

161214

162215
# CHECK-LABEL: TEST: inferReturnTypes
@@ -247,7 +300,6 @@ def testOptionalOperandOp():
247300

248301
module = Module.create()
249302
with InsertionPoint(module.body):
250-
251303
op1 = test.OptionalOperandOp()
252304
# CHECK: op1.input is None: True
253305
print(f"op1.input is None: {op1.input is None}")

0 commit comments

Comments
 (0)