Skip to content

Commit a1858db

Browse files
authored
[Binding] enable binding of cpuruntime dialect (#102)
enable binding of cpuruntime dialect
1 parent fd8c735 commit a1858db

File tree

16 files changed

+185
-9
lines changed

16 files changed

+185
-9
lines changed

include/gc-c/Dialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ extern "C" {
2727
#endif
2828

2929
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(OneDNNGraph, onednn_graph);
30+
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(CPURuntime, cpuruntime);
3031

3132
#ifdef __cplusplus
3233
}

include/gc-c/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
2728
#include "gc/Transforms/Passes.capi.h.inc"
2829
#ifdef __cplusplus
2930
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
add_mlir_dialect(OneDNNGraphOps onednn_graph)
2-
# add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
3-
# add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)
2+
add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
3+
add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)

lib/gc/CAPI/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,7 @@ add_mlir_public_c_api_library(GcCAPI
33
Passes.cpp
44
LINK_LIBS PUBLIC
55
MLIROneDNNGraph
6+
MLIRCPURuntimeDialect
67
GCPasses
8+
MLIRCPURuntimeTransforms
79
)

lib/gc/CAPI/Dialects.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc-c/Dialects.h"
10+
#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h"
1011
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
1112
#include "mlir/CAPI/Registration.h"
1213

1314
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(OneDNNGraph, onednn_graph,
14-
mlir::onednn_graph::OneDNNGraphDialect)
15+
mlir::onednn_graph::OneDNNGraphDialect)
16+
17+
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(CPURuntime, cpuruntime,
18+
mlir::cpuruntime::CPURuntimeDialect)

lib/gc/CAPI/Passes.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,20 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc/Transforms/Passes.h"
10+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
1011
#include "mlir-c/Pass.h"
1112
#include "mlir/CAPI/Pass.h"
1213

14+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
1315
#include "gc/Transforms/Passes.capi.h.inc"
1416
using namespace mlir::gc;
17+
using namespace mlir::cpuruntime;
1518

1619
#ifdef __cplusplus
1720
extern "C" {
1821
#endif
1922

23+
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.cpp.inc"
2024
#include "gc/Transforms/Passes.capi.cpp.inc"
2125

2226
#ifdef __cplusplus

python/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,15 @@ declare_mlir_dialect_python_bindings(
5252
DIALECT_NAME onednn_graph
5353
)
5454

55+
declare_mlir_dialect_python_bindings(
56+
ADD_TO_PARENT GcPythonSources
57+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gc_mlir"
58+
TD_FILE dialects/CPURuntimeOps.td
59+
SOURCES
60+
dialects/cpuruntime.py
61+
DIALECT_NAME cpuruntime
62+
)
63+
5564
declare_mlir_python_extension(GcPythonSources.Extension
5665
MODULE_NAME _gc_mlir
5766
ADD_TO_PARENT GcPythonSources

python/MainModule.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,20 @@ PYBIND11_MODULE(_gc_mlir, m) {
4040
}
4141
},
4242
py::arg("context") = py::none(), py::arg("load") = true);
43+
44+
//===----------------------------------------------------------------------===//
45+
// CPURuntime
46+
//===----------------------------------------------------------------------===//
47+
mlirRegisterCPURuntimePasses();
48+
auto cpuruntimeM = m.def_submodule("cpuruntime");
49+
cpuruntimeM.def(
50+
"register_dialect",
51+
[](MlirContext context, bool load) {
52+
MlirDialectHandle dialect = mlirGetDialectHandle__cpuruntime__();
53+
mlirDialectHandleRegisterDialect(dialect, context);
54+
if (load) {
55+
mlirDialectHandleLoadDialect(dialect, context);
56+
}
57+
},
58+
py::arg("context") = py::none(), py::arg("load") = true);
4359
}

python/gc_mlir/_mlir_libs/_site_initialize_0.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99

1010
def context_init_hook(context):
11-
from ._gc_mlir.onednn_graph import register_dialect
12-
13-
register_dialect(context)
11+
from ._gc_mlir.onednn_graph import register_dialect as register_onednn_graph_dialect
12+
from ._gc_mlir.cpuruntime import register_dialect as register_cpuruntime_dialect
13+
14+
register_onednn_graph_dialect(context)
15+
register_cpuruntime_dialect(context)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- CPURuntimeOps.td - Entry point for bindings --------*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_CPURUNTIME_OPS
10+
#define PYTHON_BINDINGS_CPURUNTIME_OPS
11+
12+
include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.td"
13+
14+
#endif // PYTHON_BINDINGS_CPURUNTIME_OPS

python/gc_mlir/dialects/OneDNNGraphOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@
1111

1212
include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.td"
1313

14-
#endif
14+
#endif // PYTHON_BINDINGS_ONEDNNGRAPH_OPS

python/gc_mlir/dialects/cpuruntime.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# ===-- cpuruntime.py - MLIR Python source -------------------*- Python -*-===#
2+
#
3+
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
#
7+
# ===-----------------------------------------------------------------------===#
8+
9+
from ._cpuruntime_ops_gen import *
10+
from .._mlir_libs._gc_mlir.cpuruntime import *

test/gc/python/dialects/cpuruntime.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# See the License for the specific language governing permissions and
2+
# limitations under the License.
3+
# ===============================================================================
4+
# RUN: %python %s | FileCheck %s
5+
6+
from gc_mlir.dialects import cpuruntime, func
7+
from gc_mlir.ir import *
8+
from gc_mlir.passmanager import PassManager
9+
10+
11+
def run(f):
12+
print("\nTEST:", f.__name__)
13+
f()
14+
return f
15+
16+
17+
# CHECK-LABEL: TEST: testCPURuntimeOps
18+
@run
19+
def testCPURuntimeOps():
20+
with Context() as ctx, Location.unknown():
21+
module = Module.create()
22+
with InsertionPoint(module.body):
23+
24+
@func.FuncOp.from_py_func(F32Type.get(), IntegerType.get_signless(32))
25+
def do_print(arg1, arg2):
26+
cpuruntime.printf("Hello world %f %d", [arg1, arg2])
27+
return
28+
29+
# CHECK-LABEL: func @do_print(
30+
# CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: i32) {
31+
# CHECK: cpuruntime.printf "Hello world %f %d" %[[ARG_0]], %[[ARG_1:.*]] : f32, i32
32+
# CHECK: return
33+
# CHECK: }
34+
print(module)
35+
36+
37+
# CHECK-LABEL: TEST: testConvertToLLVM
38+
@run
39+
def testConvertToLLVM():
40+
with Context():
41+
module = Module.parse(
42+
"""
43+
module {
44+
func.func @do_print(%arg0: f32, %arg1: i32) {
45+
cpuruntime.printf "Hello world %f %d" %arg0, %arg1 : f32, i32
46+
return
47+
}
48+
}
49+
"""
50+
)
51+
pm = PassManager.parse("builtin.module(convert-cpuruntime-to-llvm)")
52+
# CHECK-NOT: cpuruntime.printf
53+
# CHECK: llvm.call @printf
54+
pm.run(module.operation)
55+
print(module)
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# See the License for the specific language governing permissions and
2+
# limitations under the License.
3+
# ===============================================================================
4+
# RUN: %python %s | FileCheck %s
5+
6+
from gc_mlir.ir import *
7+
from gc_mlir.dialects import onednn_graph, func
8+
from gc_mlir.passmanager import PassManager
9+
10+
11+
def run(f):
12+
print("\nTEST:", f.__name__)
13+
f()
14+
return f
15+
16+
17+
# CHECK-LABEL: TEST: testOneDNNGraphOps
18+
@run
19+
def testOneDNNGraphOps():
20+
with Context() as ctx, Location.unknown():
21+
module = Module.create()
22+
with InsertionPoint(module.body):
23+
f32 = F32Type.get(ctx)
24+
tensor_type = RankedTensorType.get([128, 128], f32)
25+
26+
@func.FuncOp.from_py_func(tensor_type, tensor_type)
27+
def entry(arg1, arg2):
28+
res1 = onednn_graph.matmul(
29+
arg1, arg2, bias=None, transpose_a=False, transpose_b=False
30+
)
31+
res2 = onednn_graph.add(res1, arg2)
32+
return onednn_graph.relu(res2)
33+
34+
# CHECK: [[MM:%.+]] = onednn_graph.matmul
35+
# CHECK: [[ADD:%.+]] = onednn_graph.add
36+
# CHECK: [[RELU:%.+]] = onednn_graph.relu
37+
print(module)
38+
39+
40+
# CHECK-LABEL: TEST: testConvertToLinalg
41+
@run
42+
def testConvertToLinalg():
43+
with Context():
44+
module = Module.parse(
45+
"""
46+
func.func @matmul(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> {
47+
%0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16>
48+
return %0 : tensor<128x256xbf16>
49+
}
50+
"""
51+
)
52+
pm = PassManager.parse("builtin.module(convert-onednn-graph-to-linalg)")
53+
# CHECK: [[C0:%.+]] = arith.constant 0
54+
# CHECK: [[INIT:%.+]] = tensor.empty()
55+
# CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
56+
# CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
57+
pm.run(module.operation)
58+
print(module)
File renamed without changes.

test/gc/Python/smoketest.py renamed to test/gc/python/smoketest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def run(f):
1414
return f
1515

1616

17-
# CHECK-LABEL: TEST: testCreatetOp
17+
# CHECK-LABEL: TEST: testCreateOp
1818
# CHECK onednn_graph.add
1919
@run
20-
def testCreatetOp():
20+
def testCreateOp():
2121
with Context() as ctx, Location.unknown():
2222
module = Module.create()
2323
f32 = F32Type.get(ctx)

0 commit comments

Comments
 (0)