Skip to content

Commit d898ff6

Browse files
authored
[mlir,python] Fix case when FuncOp.arg_attrs is not set (#117188)
FuncOps can have `arg_attrs`, an array of dictionary attributes associated with their arguments. E.g., ```mlir func.func @main(%arg0: tensor<8xf32> {test.attr_name = "value"}, %arg1: tensor<8x16xf32>) ``` These are exposed via the MLIR Python bindings with `my_funcop.arg_attrs`. In this case, it would return `[{test.attr_name = "value"}, {}]`, i.e., `%arg1` has an empty `DictAttr`. However, if I try and access this property from a FuncOp with an empty `arg_attrs`, e.g., ```mlir func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<8x16xf32>) ``` This raises the error: ```python return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^ KeyError: 'attempt to access a non-existent attribute' ``` This PR fixes this by returning the expected `[{}, {}]`.
1 parent 5827542 commit d898ff6

File tree

2 files changed

+31
-0
lines changed

2 files changed

+31
-0
lines changed

mlir/python/mlir/dialects/func.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
105105

106106
@property
107107
def arg_attrs(self):
108+
if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
109+
return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
108110
return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
109111

110112
@arg_attrs.setter

mlir/test/python/dialects/func.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,32 @@ def testFunctionCalls():
104104
# CHECK: %1 = call @qux() : () -> f32
105105
# CHECK: return
106106
# CHECK: }
107+
108+
109+
# CHECK-LABEL: TEST: testFunctionArgAttrs
110+
@constructAndPrintInModule
111+
def testFunctionArgAttrs():
112+
foo = func.FuncOp("foo", ([F32Type.get()], []))
113+
foo.sym_visibility = StringAttr.get("private")
114+
foo2 = func.FuncOp("foo2", ([F32Type.get(), F32Type.get()], []))
115+
foo2.sym_visibility = StringAttr.get("private")
116+
117+
empty_attr = DictAttr.get({})
118+
test_attr = DictAttr.get({"test.foo": StringAttr.get("bar")})
119+
test_attr2 = DictAttr.get({"test.baz": StringAttr.get("qux")})
120+
121+
assert len(foo.arg_attrs) == 1
122+
assert foo.arg_attrs[0] == empty_attr
123+
124+
foo.arg_attrs = [test_attr]
125+
assert foo.arg_attrs[0]["test.foo"] == StringAttr.get("bar")
126+
127+
assert len(foo2.arg_attrs) == 2
128+
assert foo2.arg_attrs == ArrayAttr.get([empty_attr, empty_attr])
129+
130+
foo2.arg_attrs = [empty_attr, test_attr2]
131+
assert foo2.arg_attrs == ArrayAttr.get([empty_attr, test_attr2])
132+
133+
134+
# CHECK: func private @foo(f32 {test.foo = "bar"})
135+
# CHECK: func private @foo2(f32, f32 {test.baz = "qux"})

0 commit comments

Comments
 (0)