Skip to content

Commit 37fe3c6

Browse files
authored
[mlir][python] Fix generation of Python bindings for async dialect (#75960)
The Python bindings generated for "async" dialect didn't include any of the "async" dialect ops. This PR fixes issues with generation of Python bindings for "async" dialect and adds a test case to use them.
1 parent 34acbb3 commit 37fe3c6

File tree

5 files changed

+66
-4
lines changed

5 files changed

+66
-4
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS AsyncOps.td)
12
add_mlir_dialect(AsyncOps async)
23
add_mlir_doc(AsyncOps AsyncDialect Dialects/ -gen-dialect-doc)

mlir/python/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ declare_mlir_dialect_python_bindings(
7979
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
8080
TD_FILE dialects/AsyncOps.td
8181
SOURCES_GLOB dialects/async_dialect/*.py
82-
DIALECT_NAME async_dialect)
82+
DIALECT_NAME async)
8383

8484
declare_mlir_dialect_python_bindings(
8585
ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -591,7 +591,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
591591

592592
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
593593
MODULE_NAME _mlirAsyncPasses
594-
ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect
594+
ADD_TO_PARENT MLIRPythonSources.Dialects.async
595595
ROOT_DIR "${PYTHON_SOURCE_DIR}"
596596
SOURCES
597597
AsyncPasses.cpp

mlir/python/mlir/dialects/async_dialect/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from .._async_dialect_ops_gen import *
5+
from .._async_ops_gen import *

mlir/test/python/dialects/async_dialect.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
# RUN: %PYTHON %s | FileCheck %s
22

33
from mlir.ir import *
4-
import mlir.dialects.async_dialect
4+
from mlir.dialects import arith
5+
import mlir.dialects.async_dialect as async_dialect
56
import mlir.dialects.async_dialect.passes
67
from mlir.passmanager import *
78

@@ -11,6 +12,19 @@ def run(f):
1112
f()
1213

1314

15+
# CHECK-LABEL: TEST: testCreateGroupOp
16+
@run
17+
def testCreateGroupOp():
18+
with Context() as ctx, Location.unknown():
19+
module = Module.create()
20+
with InsertionPoint(module.body):
21+
i32 = IntegerType.get_signless(32)
22+
group_size = arith.ConstantOp(i32, 4)
23+
async_dialect.create_group(group_size)
24+
# CHECK: %0 = "arith.constant"() <{value = 4 : i32}> : () -> i32
25+
# CHECK: %1 = "async.create_group"(%0) : (i32) -> !async.group
26+
print(module)
27+
1428
def testAsyncPass():
1529
with Context() as context:
1630
PassManager.parse("any(async-to-async-runtime)")

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,53 @@ filegroup(
331331
],
332332
)
333333

334+
##---------------------------------------------------------------------------##
335+
# Async dialect.
336+
##---------------------------------------------------------------------------##
337+
338+
gentbl_filegroup(
339+
name = "AsyncOpsPyGen",
340+
tbl_outs = [
341+
(
342+
[
343+
"-gen-python-enum-bindings",
344+
"-bind-dialect=async",
345+
],
346+
"mlir/dialects/_async_enum_gen.py",
347+
),
348+
(
349+
[
350+
"-gen-python-op-bindings",
351+
"-bind-dialect=async",
352+
],
353+
"mlir/dialects/_async_ops_gen.py",
354+
),
355+
],
356+
tblgen = "//mlir:mlir-tblgen",
357+
td_file = "mlir/dialects/AsyncOps.td",
358+
deps = [
359+
"//mlir:AsyncOpsTdFiles",
360+
"//mlir:OpBaseTdFiles",
361+
],
362+
)
363+
364+
filegroup(
365+
name = "AsyncOpsPyFiles",
366+
srcs = [
367+
":AsyncOpsPyGen",
368+
],
369+
)
370+
371+
filegroup(
372+
name = "AsyncOpsPackagePyFiles",
373+
srcs = glob(["mlir/dialects/async_dialect/*.py"]),
374+
)
375+
376+
filegroup(
377+
name = "AsyncOpsPackagePassesPyFiles",
378+
srcs = glob(["mlir/dialects/async_dialect/passes/*.py"]),
379+
)
380+
334381
##---------------------------------------------------------------------------##
335382
# Arith dialect.
336383
##---------------------------------------------------------------------------##

0 commit comments

Comments
 (0)