Skip to content

Commit 3d27d11

Browse files
authored
[mlir][sparse] Generates python bindings for SparseTensorTransformOps. (#66937)
1 parent e0aaa19 commit 3d27d11

File tree

5 files changed

+79
-0
lines changed

5 files changed

+79
-0
lines changed

mlir/python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,15 @@ declare_mlir_dialect_extension_python_bindings(
213213
"../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td"
214214
)
215215

216+
declare_mlir_dialect_extension_python_bindings(
217+
ADD_TO_PARENT MLIRPythonSources.Dialects
218+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
219+
TD_FILE dialects/SparseTensorTransformOps.td
220+
SOURCES
221+
dialects/transform/sparse_tensor.py
222+
DIALECT_NAME transform
223+
EXTENSION_NAME sparse_tensor_transform)
224+
216225
declare_mlir_dialect_extension_python_bindings(
217226
ADD_TO_PARENT MLIRPythonSources.Dialects
218227
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- SparseTensorTransformOps.td ------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS
10+
#define PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td"
13+
14+
#endif
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .._sparse_tensor_transform_ops_gen import *
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import transform
5+
from mlir.dialects.transform import sparse_tensor
6+
7+
8+
def run(f):
9+
with Context(), Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
sequence = transform.SequenceOp(
13+
transform.FailurePropagationMode.Propagate,
14+
[],
15+
transform.AnyOpType.get(),
16+
)
17+
with InsertionPoint(sequence.body):
18+
f(sequence.bodyTarget)
19+
transform.YieldOp()
20+
print("\nTEST:", f.__name__)
21+
print(module)
22+
return f
23+
24+
25+
@run
26+
def testMatchSparseInOut(target):
27+
sparse_tensor.MatchSparseInOut(transform.AnyOpType.get(), target)
28+
# CHECK-LABEL: TEST: testMatchSparseInOut
29+
# CHECK: transform.sequence
30+
# CHECK-NEXT: ^{{.*}}(%[[ARG0:.*]]: !transform.any_op):
31+
# CHECK-NEXT: transform.sparse_tensor.match.sparse_inout %[[ARG0]]

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,25 @@ gentbl_filegroup(
12321232
],
12331233
)
12341234

1235+
gentbl_filegroup(
1236+
name = "SparseTensorTransformOpsPyGen",
1237+
tbl_outs = [
1238+
(
1239+
[
1240+
"-gen-python-op-bindings",
1241+
"-bind-dialect=transform",
1242+
"-dialect-extension=sparse_tensor_transform",
1243+
],
1244+
"mlir/dialects/_sparse_tensor_transform_ops_gen.py",
1245+
),
1246+
],
1247+
tblgen = "//mlir:mlir-tblgen",
1248+
td_file = "mlir/dialects/SparseTensorTransformOps.td",
1249+
deps = [
1250+
"//mlir:SparseTensorTransformOpsTdFiles",
1251+
],
1252+
)
1253+
12351254
gentbl_filegroup(
12361255
name = "TensorTransformOpsPyGen",
12371256
tbl_outs = [
@@ -1309,6 +1328,7 @@ filegroup(
13091328
":LoopTransformOpsPyGen",
13101329
":MemRefTransformOpsPyGen",
13111330
":PDLTransformOpsPyGen",
1331+
":SparseTensorTransformOpsPyGen",
13121332
":StructureTransformEnumPyGen",
13131333
":StructuredTransformOpsPyGen",
13141334
":TensorTransformOpsPyGen",

0 commit comments

Comments
 (0)