Skip to content

Commit 1f8618f

Browse files
committed
[mlir] python enum bindings generator
Add an ODS (tablegen) backend to generate Python enum classes and attribute builders for enum attributes defined in ODS. This will allow us to keep the enum attribute definitions in sync between C++ and Python, as opposed to handwritten enum classes in Python that may end up using mismatching values. This also makes autogenerated bindings more convenient even in absence of mixins. Use this backend for the transform dialect failure propagation mode enum attribute as demonstration. Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D156553
1 parent 235390d commit 1f8618f

File tree

7 files changed

+272
-79
lines changed

7 files changed

+272
-79
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,15 @@ declare_mlir_dialect_python_bindings(
134134
_mlir_libs/_mlir/dialects/transform/__init__.pyi
135135
DIALECT_NAME transform)
136136

137+
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td")
138+
mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings)
139+
add_public_tablegen_target(MLIRTransformDialectPyEnumGen)
140+
declare_mlir_python_sources(
141+
MLIRPythonSources.Dialects.transform.enum_gen
142+
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
143+
ADD_TO_PARENT MLIRPythonSources.Dialects.transform
144+
SOURCES "dialects/_transform_enum_gen.py")
145+
137146
declare_mlir_dialect_extension_python_bindings(
138147
ADD_TO_PARENT MLIRPythonSources.Dialects
139148
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

mlir/python/mlir/dialects/_transform_ops_ext.py

Lines changed: 54 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,68 +15,66 @@
1515

1616

1717
class CastOp:
18-
19-
def __init__(
20-
self,
21-
result_type: Type,
22-
target: Union[Operation, Value],
23-
*,
24-
loc=None,
25-
ip=None,
26-
):
27-
super().__init__(
28-
result_type, _get_op_result_or_value(target), loc=loc, ip=ip
29-
)
18+
def __init__(
19+
self,
20+
result_type: Type,
21+
target: Union[Operation, Value],
22+
*,
23+
loc=None,
24+
ip=None,
25+
):
26+
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
3027

3128

3229
class ApplyPatternsOp:
30+
def __init__(
31+
self,
32+
target: Union[Operation, Value, OpView],
33+
*,
34+
loc=None,
35+
ip=None,
36+
):
37+
operands = []
38+
operands.append(_get_op_result_or_value(target))
39+
super().__init__(
40+
self.build_generic(
41+
attributes={},
42+
results=[],
43+
operands=operands,
44+
successors=None,
45+
regions=None,
46+
loc=loc,
47+
ip=ip,
48+
)
49+
)
50+
self.regions[0].blocks.append()
3351

34-
def __init__(
35-
self,
36-
target: Union[Operation, Value, OpView],
37-
*,
38-
loc=None,
39-
ip=None,
40-
):
41-
operands = []
42-
operands.append(_get_op_result_or_value(target))
43-
super().__init__(
44-
self.build_generic(attributes={},
45-
results=[],
46-
operands=operands,
47-
successors=None,
48-
regions=None,
49-
loc=loc,
50-
ip=ip))
51-
self.regions[0].blocks.append()
52-
53-
@property
54-
def patterns(self) -> Block:
55-
return self.regions[0].blocks[0]
52+
@property
53+
def patterns(self) -> Block:
54+
return self.regions[0].blocks[0]
5655

5756

5857
class testGetParentOp:
59-
60-
def __init__(
61-
self,
62-
result_type: Type,
63-
target: Union[Operation, Value],
64-
*,
65-
isolated_from_above: bool = False,
66-
op_name: Optional[str] = None,
67-
deduplicate: bool = False,
68-
loc=None,
69-
ip=None,
70-
):
71-
super().__init__(
72-
result_type,
73-
_get_op_result_or_value(target),
74-
isolated_from_above=isolated_from_above,
75-
op_name=op_name,
76-
deduplicate=deduplicate,
77-
loc=loc,
78-
ip=ip,
79-
)
58+
def __init__(
59+
self,
60+
result_type: Type,
61+
target: Union[Operation, Value],
62+
*,
63+
isolated_from_above: bool = False,
64+
op_name: Optional[str] = None,
65+
deduplicate: bool = False,
66+
loc=None,
67+
ip=None,
68+
):
69+
super().__init__(
70+
result_type,
71+
_get_op_result_or_value(target),
72+
isolated_from_above=isolated_from_above,
73+
op_name=op_name,
74+
deduplicate=deduplicate,
75+
loc=loc,
76+
ip=ip,
77+
)
8078

8179

8280
class MergeHandlesOp:
@@ -130,12 +128,6 @@ def __init__(
130128
else None
131129
)
132130
root_type = root.type if not isinstance(target, Type) else target
133-
if not isinstance(failure_propagation_mode, Attribute):
134-
failure_propagation_mode_attr = IntegerAttr.get(
135-
IntegerType.get_signless(32), failure_propagation_mode._as_int()
136-
)
137-
else:
138-
failure_propagation_mode_attr = failure_propagation_mode
139131

140132
if extra_bindings is None:
141133
extra_bindings = []
@@ -152,7 +144,7 @@ def __init__(
152144

153145
super().__init__(
154146
results_=results,
155-
failure_propagation_mode=failure_propagation_mode_attr,
147+
failure_propagation_mode=failure_propagation_mode,
156148
root=root,
157149
extra_bindings=extra_bindings,
158150
)

mlir/python/mlir/dialects/transform/__init__.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,6 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from enum import Enum
6-
7-
8-
class FailurePropagationMode(Enum):
9-
"""Propagation mode for silenceable errors."""
10-
11-
PROPAGATE = 1
12-
SUPPRESS = 2
13-
14-
def _as_int(self):
15-
if self is FailurePropagationMode.PROPAGATE:
16-
return 1
17-
18-
assert self is FailurePropagationMode.SUPPRESS
19-
return 2
20-
21-
5+
from .._transform_enum_gen import *
226
from .._transform_ops_gen import *
237
from ..._mlir_libs._mlirDialectsTransform import *
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-tblgen -gen-python-enum-bindings %s -I %S/../../include | FileCheck %s
2+
3+
include "mlir/IR/EnumAttr.td"
4+
5+
// CHECK: Autogenerated by mlir-tblgen; don't manually edit.
6+
7+
// CHECK: from enum import Enum
8+
// CHECK: from ._ods_common import _cext as _ods_cext
9+
// CHECK: _ods_ir = _ods_cext.ir
10+
11+
def One : I32EnumAttrCase<"CaseOne", 1, "one">;
12+
def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
13+
14+
def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>;
15+
// CHECK: def _register_attribute_builder(kind):
16+
// CHECK: def decorator_builder(func):
17+
// CHECK: _ods_ir.AttrBuilder.insert(kind, func)
18+
// CHECK: return func
19+
// CHECK: return decorator_builder
20+
21+
// CHECK-LABEL: class MyEnum(Enum):
22+
// CHECK: """An example 32-bit enum"""
23+
24+
// CHECK: CASE_ONE = 1
25+
// CHECK: CASE_TWO = 2
26+
27+
// CHECK: def _as_int(self):
28+
// CHECK: if self is MyEnum.CASE_ONE:
29+
// CHECK: return 1
30+
// CHECK: if self is MyEnum.CASE_TWO:
31+
// CHECK: return 2
32+
// CHECK: assert False, "Unknown MyEnum enum entry."
33+
34+
def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
35+
def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
36+
37+
def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
38+
// CHECK: @_register_attribute_builder("MyEnum")
39+
// CHECK: def _my_enum(x, context):
40+
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int())
41+
42+
// CHECK-LABEL: class MyEnum64(Enum):
43+
// CHECK: """An example 64-bit enum"""
44+
45+
// CHECK: CASE_ONE64 = 1
46+
// CHECK: CASE_TWO64 = 2
47+
48+
// CHECK: def _as_int(self):
49+
// CHECK: if self is MyEnum64.CASE_ONE64:
50+
// CHECK: return 1
51+
// CHECK: if self is MyEnum64.CASE_TWO64:
52+
// CHECK: return 2
53+
// CHECK: assert False, "Unknown MyEnum64 enum entry."
54+
55+
// CHECK: @_register_attribute_builder("MyEnum64")
56+
// CHECK: def _my_enum64(x, context):
57+
// CHECK: return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int())

mlir/tools/mlir-tblgen/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
1414
DialectGen.cpp
1515
DirectiveCommonGen.cpp
1616
EnumsGen.cpp
17+
EnumPythonBindingGen.cpp
1718
FormatGen.cpp
1819
LLVMIRConversionGen.cpp
1920
LLVMIRIntrinsicGen.cpp
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
10+
// generate the corresponding Python binding classes.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/TableGen/Attribute.h"
15+
#include "mlir/TableGen/GenInfo.h"
16+
#include "llvm/ADT/StringExtras.h"
17+
#include "llvm/Support/FormatVariadic.h"
18+
#include "llvm/TableGen/Record.h"
19+
20+
using namespace mlir;
21+
using namespace mlir::tblgen;
22+
23+
/// File header and includes.
24+
constexpr const char *fileHeader = R"Py(
25+
# Autogenerated by mlir-tblgen; don't manually edit.
26+
27+
from enum import Enum
28+
from ._ods_common import _cext as _ods_cext
29+
_ods_ir = _ods_cext.ir
30+
31+
# Convenience decorator for registering user-friendly Attribute builders.
32+
def _register_attribute_builder(kind):
33+
def decorator_builder(func):
34+
_ods_ir.AttrBuilder.insert(kind, func)
35+
return func
36+
37+
return decorator_builder
38+
39+
)Py";
40+
41+
/// Makes enum case name Python-compatible, i.e. UPPER_SNAKE_CASE.
42+
static std::string makePythonEnumCaseName(StringRef name) {
43+
return StringRef(llvm::convertToSnakeFromCamelCase(name)).upper();
44+
}
45+
46+
/// Emits the Python class for the given enum.
47+
static void emitEnumClass(StringRef enumName, StringRef description,
48+
ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
49+
os << llvm::formatv("class {0}(Enum):\n", enumName);
50+
if (!description.empty())
51+
os << llvm::formatv(" \"\"\"{0}\"\"\"\n", description);
52+
os << "\n";
53+
54+
for (const EnumAttrCase &enumCase : cases) {
55+
os << llvm::formatv(" {0} = {1}\n",
56+
makePythonEnumCaseName(enumCase.getSymbol()),
57+
enumCase.getValue());
58+
}
59+
60+
os << "\n";
61+
os << llvm::formatv(" def _as_int(self):\n");
62+
for (const EnumAttrCase &enumCase : cases) {
63+
os << llvm::formatv(" if self is {0}.{1}:\n", enumName,
64+
makePythonEnumCaseName(enumCase.getSymbol()));
65+
os << llvm::formatv(" return {0}\n", enumCase.getValue());
66+
}
67+
os << llvm::formatv(" assert False, \"Unknown {0} enum entry.\"\n\n\n",
68+
enumName);
69+
}
70+
71+
/// Attempts to extract the bitwidth B from string "uintB_t" describing the
72+
/// type. This bitwidth information is not readily available in ODS. Returns
73+
/// `false` on success, `true` on failure.
74+
static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
75+
if (!uintType.consume_front("uint"))
76+
return true;
77+
if (!uintType.consume_back("_t"))
78+
return true;
79+
return uintType.getAsInteger(/*Radix=*/10, bitwidth);
80+
}
81+
82+
/// Emits an attribute builder for the given enum attribute to support automatic
83+
/// conversion between enum values and attributes in Python. Returns
84+
/// `false` on success, `true` on failure.
85+
static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
86+
int64_t bitwidth;
87+
if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
88+
llvm::errs() << "failed to identify bitwidth of "
89+
<< enumAttr.getUnderlyingType();
90+
return true;
91+
}
92+
93+
os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
94+
enumAttr.getAttrDefName());
95+
os << llvm::formatv(
96+
"def _{0}(x, context):\n",
97+
llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
98+
os << llvm::formatv(
99+
" return "
100+
"_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
101+
"context=context), x._as_int())\n\n",
102+
bitwidth);
103+
return false;
104+
}
105+
106+
/// Emits Python bindings for all enums in the record keeper. Returns
107+
/// `false` on success, `true` on failure.
108+
static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
109+
raw_ostream &os) {
110+
os << fileHeader;
111+
std::vector<llvm::Record *> defs =
112+
recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
113+
for (const llvm::Record *def : defs) {
114+
EnumAttr enumAttr(*def);
115+
if (enumAttr.isBitEnum()) {
116+
llvm::errs() << "bit enums not supported\n";
117+
return true;
118+
}
119+
emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
120+
enumAttr.getAllCases(), os);
121+
emitAttributeBuilder(enumAttr, os);
122+
}
123+
return false;
124+
}
125+
126+
// Registers the enum utility generator to mlir-tblgen.
127+
static mlir::GenRegistration
128+
genPythonEnumBindings("gen-python-enum-bindings",
129+
"Generate Python bindings for enum attributes",
130+
&emitPythonEnums);

0 commit comments

Comments
 (0)