Skip to content

[Binding] enable binding of cpuruntime dialect #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/gc-c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ extern "C" {
#endif

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(OneDNNGraph, onednn_graph);
MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(CPURuntime, cpuruntime);

#ifdef __cplusplus
}
Expand Down
1 change: 1 addition & 0 deletions include/gc-c/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
extern "C" {
#endif

#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
#include "gc/Transforms/Passes.capi.h.inc"
#ifdef __cplusplus
}
Expand Down
4 changes: 2 additions & 2 deletions include/gc/Dialect/OneDNNGraph/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
add_mlir_dialect(OneDNNGraphOps onednn_graph)
# add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
# add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)
add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)
2 changes: 2 additions & 0 deletions lib/gc/CAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@ add_mlir_public_c_api_library(GcCAPI
Passes.cpp
LINK_LIBS PUBLIC
MLIROneDNNGraph
MLIRCPURuntimeDialect
GCPasses
MLIRCPURuntimeTransforms
)
6 changes: 5 additions & 1 deletion lib/gc/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
//===----------------------------------------------------------------------===//

#include "gc-c/Dialects.h"
#include "gc/Dialect/CPURuntime/IR/CPURuntimeDialect.h"
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
#include "mlir/CAPI/Registration.h"

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(OneDNNGraph, onednn_graph,
mlir::onednn_graph::OneDNNGraphDialect)
mlir::onednn_graph::OneDNNGraphDialect)

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(CPURuntime, cpuruntime,
mlir::cpuruntime::CPURuntimeDialect)
4 changes: 4 additions & 0 deletions lib/gc/CAPI/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,20 @@
//===----------------------------------------------------------------------===//

#include "gc/Transforms/Passes.h"
#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.h"
#include "mlir-c/Pass.h"
#include "mlir/CAPI/Pass.h"

#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.h.inc"
#include "gc/Transforms/Passes.capi.h.inc"
using namespace mlir::gc;
using namespace mlir::cpuruntime;

#ifdef __cplusplus
extern "C" {
#endif

#include "gc/Dialect/CPURuntime/Transforms/CPURuntimePasses.capi.cpp.inc"
#include "gc/Transforms/Passes.capi.cpp.inc"

#ifdef __cplusplus
Expand Down
9 changes: 9 additions & 0 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ declare_mlir_dialect_python_bindings(
DIALECT_NAME onednn_graph
)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT GcPythonSources
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/gc_mlir"
TD_FILE dialects/CPURuntimeOps.td
SOURCES
dialects/cpuruntime.py
DIALECT_NAME cpuruntime
)

declare_mlir_python_extension(GcPythonSources.Extension
MODULE_NAME _gc_mlir
ADD_TO_PARENT GcPythonSources
Expand Down
16 changes: 16 additions & 0 deletions python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,20 @@ PYBIND11_MODULE(_gc_mlir, m) {
}
},
py::arg("context") = py::none(), py::arg("load") = true);

//===----------------------------------------------------------------------===//
// CPURuntime
//===----------------------------------------------------------------------===//
mlirRegisterCPURuntimePasses();
auto cpuruntimeM = m.def_submodule("cpuruntime");
cpuruntimeM.def(
"register_dialect",
[](MlirContext context, bool load) {
MlirDialectHandle dialect = mlirGetDialectHandle__cpuruntime__();
mlirDialectHandleRegisterDialect(dialect, context);
if (load) {
mlirDialectHandleLoadDialect(dialect, context);
}
},
py::arg("context") = py::none(), py::arg("load") = true);
}
8 changes: 5 additions & 3 deletions python/gc_mlir/_mlir_libs/_site_initialize_0.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@


def context_init_hook(context):
from ._gc_mlir.onednn_graph import register_dialect

register_dialect(context)
from ._gc_mlir.onednn_graph import register_dialect as register_onednn_graph_dialect
from ._gc_mlir.cpuruntime import register_dialect as register_cpuruntime_dialect

register_onednn_graph_dialect(context)
register_cpuruntime_dialect(context)
14 changes: 14 additions & 0 deletions python/gc_mlir/dialects/CPURuntimeOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===-- CPURuntimeOps.td - Entry point for bindings --------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_CPURUNTIME_OPS
#define PYTHON_BINDINGS_CPURUNTIME_OPS

include "gc/Dialect/CPURuntime/IR/CPURuntimeOps.td"

#endif // PYTHON_BINDINGS_CPURUNTIME_OPS
2 changes: 1 addition & 1 deletion python/gc_mlir/dialects/OneDNNGraphOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.td"

#endif
#endif // PYTHON_BINDINGS_ONEDNNGRAPH_OPS
10 changes: 10 additions & 0 deletions python/gc_mlir/dialects/cpuruntime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# ===-- cpuruntime.py - MLIR Python source -------------------*- Python -*-===#
#
# This file is licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
# ===-----------------------------------------------------------------------===#

from ._cpuruntime_ops_gen import *
from .._mlir_libs._gc_mlir.cpuruntime import *
55 changes: 55 additions & 0 deletions test/gc/python/dialects/cpuruntime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
# RUN: %python %s | FileCheck %s

from gc_mlir.dialects import cpuruntime, func
from gc_mlir.ir import *
from gc_mlir.passmanager import PassManager


def run(f):
print("\nTEST:", f.__name__)
f()
return f


# CHECK-LABEL: TEST: testCPURuntimeOps
@run
def testCPURuntimeOps():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):

@func.FuncOp.from_py_func(F32Type.get(), IntegerType.get_signless(32))
def do_print(arg1, arg2):
cpuruntime.printf("Hello world %f %d", [arg1, arg2])
return

# CHECK-LABEL: func @do_print(
# CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: i32) {
# CHECK: cpuruntime.printf "Hello world %f %d" %[[ARG_0]], %[[ARG_1:.*]] : f32, i32
# CHECK: return
# CHECK: }
print(module)


# CHECK-LABEL: TEST: testConvertToLLVM
@run
def testConvertToLLVM():
with Context():
module = Module.parse(
"""
module {
func.func @do_print(%arg0: f32, %arg1: i32) {
cpuruntime.printf "Hello world %f %d" %arg0, %arg1 : f32, i32
return
}
}
"""
)
pm = PassManager.parse("builtin.module(convert-cpuruntime-to-llvm)")
# CHECK-NOT: cpuruntime.printf
# CHECK: llvm.call @printf
pm.run(module.operation)
print(module)
58 changes: 58 additions & 0 deletions test/gc/python/dialects/onednn_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
# RUN: %python %s | FileCheck %s

from gc_mlir.ir import *
from gc_mlir.dialects import onednn_graph, func
from gc_mlir.passmanager import PassManager


def run(f):
print("\nTEST:", f.__name__)
f()
return f


# CHECK-LABEL: TEST: testOneDNNGraphOps
@run
def testOneDNNGraphOps():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
f32 = F32Type.get(ctx)
tensor_type = RankedTensorType.get([128, 128], f32)

@func.FuncOp.from_py_func(tensor_type, tensor_type)
def entry(arg1, arg2):
res1 = onednn_graph.matmul(
arg1, arg2, bias=None, transpose_a=False, transpose_b=False
)
res2 = onednn_graph.add(res1, arg2)
return onednn_graph.relu(res2)

# CHECK: [[MM:%.+]] = onednn_graph.matmul
# CHECK: [[ADD:%.+]] = onednn_graph.add
# CHECK: [[RELU:%.+]] = onednn_graph.relu
print(module)


# CHECK-LABEL: TEST: testConvertToLinalg
@run
def testConvertToLinalg():
with Context():
module = Module.parse(
"""
func.func @matmul(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> {
%0 = onednn_graph.matmul %arg0, %arg1 : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16>
return %0 : tensor<128x256xbf16>
}
"""
)
pm = PassManager.parse("builtin.module(convert-onednn-graph-to-linalg)")
# CHECK: [[C0:%.+]] = arith.constant 0
# CHECK: [[INIT:%.+]] = tensor.empty()
# CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : bf16) outs([[INIT]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
# CHECK: linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs([[FILLED]] : tensor<128x256xbf16>) -> tensor<128x256xbf16>
pm.run(module.operation)
print(module)
File renamed without changes.
4 changes: 2 additions & 2 deletions test/gc/Python/smoketest.py → test/gc/python/smoketest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def run(f):
return f


# CHECK-LABEL: TEST: testCreatetOp
# CHECK-LABEL: TEST: testCreateOp
# CHECK onednn_graph.add
@run
def testCreatetOp():
def testCreateOp():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get(ctx)
Expand Down