Skip to content

Commit 2669dd6

Browse files
committed
enable binding
1 parent 20f62f0 commit 2669dd6

File tree

16 files changed

+126
-9
lines changed

16 files changed

+126
-9
lines changed

include/gc-c/Dialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ extern "C" {
2727
#endif
2828

2929
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(OneDNNGraph, onednn_graph);
30+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(CPURuntime, cpuruntime);
3031

3132
#ifdef __cplusplus
3233
}

include/gc-c/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
2728
#include "gc/Transforms/Passes.capi.h.inc"
2829
#ifdef __cplusplus
2930
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
add_mlir_dialect(OneDNNGraphOps onednn_graph)
2-
# add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
3-
# add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)
2+
add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
3+
add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)

lib/gc/CAPI/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@ add_mlir_public_c_api_library(GcCAPI
33
Passes.cpp
44
LINK_LIBS PUBLIC
55
MLIROneDNNGraph
6+
MLIRCPURuntimeDialect
67
GCPasses
8+
MLIRCPURuntimeTransforms
79
)

lib/gc/CAPI/Dialects.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
#include "gc-c/Dialects.h"
1010
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
11+
#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h"
1112
#include "mlir/CAPI/Registration.h"
1213

1314
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(OneDNNGraph, onednn_graph,
14-
mlir::onednn_graph::OneDNNGraphDialect)
15+
mlir::onednn_graph::OneDNNGraphDialect)
16+
17+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(CPURuntime, cpuruntime,
18+
mlir::cpuruntime::CPURuntimeDialect)

lib/gc/CAPI/Passes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,20 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc/Transforms/Passes.h"
10+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
1011
#include "mlir-c/Pass.h"
1112
#include "mlir/CAPI/Pass.h"
1213

14+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
1315
#include "gc/Transforms/Passes.capi.h.inc"
1416
using namespace mlir::gc;
17+
using namespace mlir::cpuruntime;
1518

1619
#ifdef __cplusplus
1720
extern "C" {
1821
#endif
1922

23+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.cpp.inc"
2024
#include "gc/Transforms/Passes.capi.cpp.inc"
2125

2226
#ifdef __cplusplus

python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ declare_mlir_dialect_python_bindings(
5252
DIALECT_NAME onednn_graph
5353
)
5454

55+
declare_mlir_dialect_python_bindings(
56+
ADD_TO_PARENT GcPythonSources
57+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gc_mlir"
58+
TD_FILE dialects/CPURuntimeOps.td
59+
SOURCES
60+
dialects/cpuruntime.py
61+
DIALECT_NAME cpuruntime
62+
)
63+
5564
declare_mlir_python_extension(GcPythonSources.Extension
5665
MODULE_NAME _gc_mlir
5766
ADD_TO_PARENT GcPythonSources

python/MainModule.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,20 @@ PYBIND11_MODULE(_gc_mlir, m) {
4040
}
4141
},
4242
py::arg("context") = py::none(), py::arg("load") = true);
43+
44+
//===----------------------------------------------------------------------===//
45+
// CPURuntime
46+
//===----------------------------------------------------------------------===//
47+
mlirRegisterCPURuntimePasses();
48+
auto cpuruntimeM = m.def_submodule("cpuruntime");
49+
cpuruntimeM.def(
50+
"register_dialect",
51+
[](MlirContext context, bool load) {
52+
MlirDialectHandle dialect = mlirGetDialectHandle__cpuruntime__();
53+
mlirDialectHandleRegisterDialect(dialect, context);
54+
if (load) {
55+
mlirDialectHandleLoadDialect(dialect, context);
56+
}
57+
},
58+
py::arg("context") = py::none(), py::arg("load") = true);
4359
}

python/gc_mlir/_mlir_libs/_site_initialize_0.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55

66
def context_init_hook(context):
7-
from ._gc_mlir.onednn_graph import register_dialect
8-
9-
register_dialect(context)
7+
from ._gc_mlir.onednn_graph import register_dialect as register_onednn_graph_dialect
8+
from ._gc_mlir.cpuruntime import register_dialect as register_cpuruntime_dialect
9+
10+
register_onednn_graph_dialect(context)
11+
register_cpuruntime_dialect(context)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#ifndef PYTHON_BINDINGS_CPURUNTIME_OPS
2+
#define PYTHON_BINDINGS_CPURUNTIME_OPS
3+
4+
include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.td"
5+
6+
#endif // PYTHON_BINDINGS_CPURUNTIME_OPS

python/gc_mlir/dialects/OneDNNGraphOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111

1212
include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.td"
1313

14-
#endif
14+
#endif // PYTHON_BINDINGS_ONEDNNGRAPH_OPS

python/gc_mlir/dialects/cpuruntime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from ._cpuruntime_ops_gen import *
2+
from .._mlir_libs._gc_mlir.cpuruntime import *

test/gc/python/dialects/cpuruntime.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# See the License for the specific language governing permissions and
2+
# limitations under the License.
3+
# ===============================================================================
4+
# RUN: %python %s | FileCheck %s
5+
6+
from gc_mlir.ir import *
7+
from gc_mlir.dialects import cpuruntime, func
8+
9+
10+
def run(f):
11+
print("\nTEST:", f.__name__)
12+
f()
13+
return f
14+
15+
16+
# CHECK-LABEL: TEST: testCPURuntimeOps
17+
@run
18+
def testCPURuntimeOps():
19+
with Context() as ctx, Location.unknown():
20+
module = Module.create()
21+
with InsertionPoint(module.body):
22+
23+
@func.FuncOp.from_py_func(F32Type.get(), IntegerType.get_signless(32))
24+
def do_print(arg1, arg2):
25+
cpuruntime.printf("Hello world %f %d", [arg1, arg2])
26+
return
27+
28+
# CHECK-LABEL: func @do_print(
29+
# CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: i32) {
30+
# CHECK: cpuruntime.printf "Hello world %f %d" %[[ARG_0]], %[[ARG_1:.*]] : f32, i32
31+
# CHECK: return
32+
# CHECK: }
33+
print(module)
34+
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# See the License for the specific language governing permissions and
2+
# limitations under the License.
3+
# ===============================================================================
4+
# RUN: %python %s | FileCheck %s
5+
6+
from gc_mlir.ir import *
7+
from gc_mlir.dialects import onednn_graph, func
8+
9+
10+
def run(f):
11+
print("\nTEST:", f.__name__)
12+
f()
13+
return f
14+
15+
16+
# CHECK-LABEL: TEST: testOneDNNGraphOps
17+
@run
18+
def testOneDNNGraphOps():
19+
with Context() as ctx, Location.unknown():
20+
module = Module.create()
21+
with InsertionPoint(module.body):
22+
f32 = F32Type.get(ctx)
23+
tensor_type = RankedTensorType.get([128, 128], f32)
24+
25+
@func.FuncOp.from_py_func(tensor_type, tensor_type)
26+
def entry(arg1, arg2):
27+
res1 = onednn_graph.matmul(
28+
arg1, arg2, bias=None, transpose_a=False, transpose_b=False
29+
)
30+
res2 = onednn_graph.add(res1, arg2)
31+
return onednn_graph.relu(res2)
32+
33+
# CHECK: [[MM:%.+]] = onednn_graph.matmul
34+
# CHECK: [[ADD:%.+]] = onednn_graph.add
35+
# CHECK: [[RELU:%.+]] = onednn_graph.relu
36+
print(module)
File renamed without changes.

test/gc/Python/smoketest.py renamed to test/gc/python/smoketest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def run(f):
1414
return f
1515

1616

17-
# CHECK-LABEL: TEST: testCreatetOp
17+
# CHECK-LABEL: TEST: testCreateOp
1818
# CHECK onednn_graph.add
1919
@run
20-
def testCreatetOp():
20+
def testCreateOp():
2121
with Context() as ctx, Location.unknown():
2222
module = Module.create()
2323
f32 = F32Type.get(ctx)

0 commit comments

Comments
 (0)