Skip to content

Commit bc30b41

Browse files
[mlir] enable python bindings for nvgpu transforms (#68088)
Expose the autogenerated bindings. Co-authored-by: Martin Lücke <[email protected]>
1 parent 3e3cf77 commit bc30b41

File tree

5 files changed

+81
-0
lines changed

5 files changed

+81
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,15 @@ declare_mlir_dialect_extension_python_bindings(
200200
DIALECT_NAME transform
201201
EXTENSION_NAME memref_transform)
202202

203+
declare_mlir_dialect_extension_python_bindings(
204+
ADD_TO_PARENT MLIRPythonSources.Dialects
205+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
206+
TD_FILE dialects/NVGPUTransformOps.td
207+
SOURCES
208+
dialects/transform/nvgpu.py
209+
DIALECT_NAME transform
210+
EXTENSION_NAME nvgpu_transform)
211+
203212
declare_mlir_dialect_extension_python_bindings(
204213
ADD_TO_PARENT MLIRPythonSources.Dialects
205214
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===-- NVGPUTransformOps.td -------------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the Python bindings generator for the transform ops provided
10+
// by the NVGPU dialect.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
15+
#ifndef PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS
16+
#define PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS
17+
18+
include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td"
19+
20+
#endif // PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .._nvgpu_transform_ops_gen import *
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import transform
5+
from mlir.dialects.transform import nvgpu
6+
7+
8+
def run(f):
9+
with Context(), Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
print("\nTEST:", f.__name__)
13+
f()
14+
print(module)
15+
return f
16+
17+
18+
@run
19+
def testCreateAsyncGroups():
20+
sequence = transform.SequenceOp(
21+
transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
22+
)
23+
with InsertionPoint(sequence.body):
24+
nvgpu.CreateAsyncGroupsOp(transform.AnyOpType.get(), sequence.bodyTarget)
25+
transform.YieldOp()
26+
# CHECK-LABEL: TEST: testCreateAsyncGroups
27+
# CHECK: transform.nvgpu.create_async_groups

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,6 +1209,25 @@ gentbl_filegroup(
12091209
],
12101210
)
12111211

1212+
gentbl_filegroup(
1213+
name = "NVGPUTransformOpsPyGen",
1214+
tbl_outs = [
1215+
(
1216+
[
1217+
"-gen-python-op-bindings",
1218+
"-bind-dialect=transform",
1219+
"-dialect-extension=nvgpu_transform",
1220+
],
1221+
"mlir/dialects/_nvgpu_transform_ops_gen.py",
1222+
),
1223+
],
1224+
tblgen = "//mlir:mlir-tblgen",
1225+
td_file = "mlir/dialects/NVGPUTransformOps.td",
1226+
deps = [
1227+
"//mlir:NVGPUTransformOpsTdFiles",
1228+
],
1229+
)
1230+
12121231
gentbl_filegroup(
12131232
name = "PDLTransformOpsPyGen",
12141233
tbl_outs = [
@@ -1327,6 +1346,7 @@ filegroup(
13271346
":GPUTransformOpsPyGen",
13281347
":LoopTransformOpsPyGen",
13291348
":MemRefTransformOpsPyGen",
1349+
":NVGPUTransformOpsPyGen",
13301350
":PDLTransformOpsPyGen",
13311351
":SparseTensorTransformOpsPyGen",
13321352
":StructureTransformEnumPyGen",

0 commit comments

Comments
 (0)