Skip to content

Commit a713c16

Browse files
authored
[Dialect] Add basic oneDNN Graph dialect (#43)
1 parent e0d4b65 commit a713c16

20 files changed

+436
-75
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ add_subdirectory(src)
6464
set(GC_LIB_LINKED_LIBS
6565
MLIRLinalgx
6666
MLIRMicrokernel
67-
MLIROnednnGraph
67+
MLIROneDNNGraph
6868
)
6969
add_library(graph_compiler SHARED ${GC_LIB_SOURCES})
7070
target_include_directories(graph_compiler PUBLIC ${GC_LIB_INCLUDES})

include/gc/Dialect/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
add_subdirectory(OnednnGraph)
1+
add_subdirectory(OneDNNGraph)
22
add_subdirectory(Microkernel)
33
add_subdirectory(Linalgx)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
add_mlir_dialect(OneDNNGraphOps onednn_graph)
2+
add_mlir_doc(OneDNNGraphOps OneDNNGraphOps gc/Dialect/OneDNNGraph/ -gen-op-doc)
3+
add_mlir_doc(OneDNNGraphDialect OneDNNGraphDialect gc/Dialect/OneDNNGraph/ -gen-dialect-doc)

include/gc/Dialect/OnednnGraph/OnednnGraphDialect.h renamed to include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- OnednnGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===//
1+
//===- OneDNNGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -10,7 +10,10 @@
1010
#define GC_DIALECTS_ONEDNNGRAPHDIALECT_H
1111

1212
#include "mlir/IR/Dialect.h"
13+
#include "mlir/IR/OpDefinition.h"
14+
#include "mlir/IR/OpImplementation.h"
1315

14-
#include "gc/Dialect/OnednnGraph/OnednnGraphOpsDialect.h.inc"
16+
#define GET_OP_CLASSES
17+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsDialect.h.inc"
1518

1619
#endif // GC_DIALECTS_ONEDNNGRAPHDIALECT_H

include/gc/Dialect/OnednnGraph/OnednnGraphDialect.td renamed to include/gc/Dialect/OneDNNGraph/OneDNNGraphDialect.td

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- OnednnGraphDialect.td - OneDNN input dialect --------*- tablegen -*-===//
1+
//===- OneDNNGraphDialect.td - OneDNN input dialect --------*- tablegen -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,15 +15,13 @@ include "mlir/IR/OpBase.td"
1515
// OneDNNGraph dialect definition.
1616
//===----------------------------------------------------------------------===//
1717

18-
def OnednnGraphDialect : Dialect {
18+
def OneDNNGraphDialect : Dialect {
1919
let name = "onednn_graph";
2020
let summary = "A dialect for oneDNN Graph.";
2121
let description = [{
2222
This dialect follows oneDNN Graph Specification.
2323
}];
2424
let cppNamespace = "::mlir::onednn_graph";
25-
26-
let useDefaultTypePrinterParser = 1;
2725
}
2826

2927
#endif // ONEDNNGRAPH_DIALECT

include/gc/Dialect/OnednnGraph/OnednnGraphOps.h renamed to include/gc/Dialect/OneDNNGraph/OneDNNGraphOps.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- OnednnGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===//
1+
//===- OneDNNGraphOps.h - OneDNN input dialect ops --------------*- C++ -*-===//
22
//
33
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -9,9 +9,15 @@
99
#ifndef GC_DIALECTS_ONEDNNGRAPHOPS_H
1010
#define GC_DIALECTS_ONEDNNGRAPHOPS_H
1111

12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "mlir/IR/Dialect.h"
1216
#include "mlir/IR/OpDefinition.h"
17+
#include "mlir/Interfaces/InferTypeOpInterface.h"
18+
#include "mlir/Interfaces/SideEffectInterfaces.h"
1319

1420
#define GET_OP_CLASSES
15-
#include "gc/Dialect/OnednnGraph/OnednnGraphOps.h.inc"
21+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h.inc"
1622

1723
#endif // GC_DIALECTS_ONEDNNGRAPHOPS_H
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
//===- OneDNNGraphOps.td - OneDNN input dialect ops --------*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef ONEDNNGRAPH_OPS
10+
#define ONEDNNGRAPH_OPS
11+
12+
include "mlir/Interfaces/InferTypeOpInterface.td"
13+
include "mlir/Interfaces/SideEffectInterfaces.td"
14+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
15+
include "mlir/IR/AttrTypeBase.td"
16+
include "mlir/IR/OpBase.td"
17+
include "OneDNNGraphDialect.td"
18+
include "OneDNNGraphTypes.td"
19+
20+
//===----------------------------------------------------------------------===//
21+
// Base OneDNNGraph operation definition.
22+
//===----------------------------------------------------------------------===//
23+
24+
class OneDNNGraph_Op<string mnemonic, list<Trait> traits = []> :
25+
Op<OneDNNGraphDialect, mnemonic, traits>;
26+
27+
class OneDNNGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
28+
OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType]> {
29+
let arguments = (ins OneDNNGraph_LogicalTensor:$input_0,
30+
OneDNNGraph_LogicalTensor:$input_1);
31+
let results = (outs OneDNNGraph_LogicalTensor:$result);
32+
33+
let assemblyFormat =
34+
"operands attr-dict `:` functional-type(operands, results)";
35+
}
36+
37+
class OneDNNGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
38+
OneDNNGraph_Op<mnemonic, traits # [SameOperandsAndResultType]> {
39+
let arguments = (ins OneDNNGraph_LogicalTensor:$operand);
40+
let results = (outs OneDNNGraph_LogicalTensor:$result);
41+
42+
let assemblyFormat =
43+
"operands attr-dict `:` functional-type(operands, results)";
44+
}
45+
46+
//===----------------------------------------------------------------------===//
47+
// OneDNNGraph op definitions
48+
//===----------------------------------------------------------------------===//
49+
50+
def OneDNNGraph_MatMulOp :
51+
OneDNNGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
52+
let summary = "Generalized matrix multiplication";
53+
let description = [{
54+
`https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html`
55+
}];
56+
57+
let arguments = (ins OneDNNGraph_LogicalTensor:$input_a,
58+
OneDNNGraph_LogicalTensor:$input_b,
59+
Optional<OneDNNGraph_LogicalTensor>:$bias,
60+
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
61+
DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
62+
let results = (outs OneDNNGraph_LogicalTensor:$result);
63+
64+
let assemblyFormat =
65+
"operands attr-dict `:` functional-type(operands, results)";
66+
}
67+
68+
def OneDNNGraph_ReLUOp : OneDNNGraph_ElemwiseUnaryOp<"relu"> {
69+
let summary = "element-wise relu";
70+
let description = [{
71+
`https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html`
72+
}];
73+
}
74+
75+
def OneDNNGraph_AddOp : OneDNNGraph_ElemwiseBinaryOp<"add", [Commutative]> {
76+
let summary = "element-wise addition with multi-directional broadcast";
77+
let description = [{
78+
`https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html`
79+
}];
80+
}
81+
82+
#endif // ONEDNNGRAPH_OPS
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//===- OneDNNGraphTypes.h - OneDNN input dialect types ----------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef ONEDNNGRAPH_ONEDNNGRAPHTYPES_H
10+
#define ONEDNNGRAPH_ONEDNNGRAPHTYPES_H
11+
12+
#include "mlir/IR/BuiltinTypes.h"
13+
14+
#define GET_TYPEDEF_CLASSES
15+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsTypes.h.inc"
16+
17+
#endif // ONEDNNGRAPH_ONEDNNGRAPHTYPES_H
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- OneDNNGraphTypes.h - OneDNN input dialect types -----*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef ONEDNNGRAPH_TYPES
10+
#define ONEDNNGRAPH_TYPES
11+
12+
include "mlir/IR/BuiltinTypes.td"
13+
include "mlir/IR/AttrTypeBase.td"
14+
include "OneDNNGraphDialect.td"
15+
16+
//===----------------------------------------------------------------------===//
17+
// OneDNNGraph type definitions
18+
//===----------------------------------------------------------------------===//
19+
20+
def OneDNNGraph_DataType : AnyTypeOf<[
21+
F16,
22+
BF16,
23+
F32,
24+
SI<32>,
25+
SI<8>,
26+
UI<8>]>;
27+
28+
def OneDNNGraph_LogicalTensor : TensorOf<[OneDNNGraph_DataType]>;
29+
30+
#endif // ONEDNNGRAPH_TYPES

include/gc/Dialect/OnednnGraph/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
This file was deleted.

include/gc/Dialect/OnednnGraph/OnednnGraphOps.td

Lines changed: 0 additions & 14 deletions
This file was deleted.

lib/gc/Dialect/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
add_subdirectory(Linalgx)
22
add_subdirectory(Microkernel)
3-
add_subdirectory(OnednnGraph)
3+
add_subdirectory(OneDNNGraph)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
add_mlir_dialect_library(MLIROneDNNGraph
2+
OneDNNGraphDialect.cpp
3+
OneDNNGraphOps.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${PROJECT_SOURCE_DIR}/include/gc/Dialect/OneDNNGraph
7+
8+
DEPENDS
9+
MLIROneDNNGraphOpsIncGen
10+
11+
LINK_LIBS PUBLIC
12+
MLIRIR
13+
)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- OneDNNGraphDialect.h - OneDNN input dialect --------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphDialect.h"
10+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.h"
11+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphTypes.h"
12+
13+
using namespace mlir;
14+
using namespace mlir::onednn_graph;
15+
16+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOpsDialect.cpp.inc"
17+
18+
//===----------------------------------------------------------------------===//
19+
// OneDNNGraph dialect.
20+
//===----------------------------------------------------------------------===//
21+
22+
void OneDNNGraphDialect::initialize() {
23+
addOperations<
24+
#define GET_OP_LIST
25+
#include "gc/Dialect/OneDNNGraph/OneDNNGraphOps.cpp.inc"
26+
>();
27+
}

0 commit comments

Comments
 (0)