Skip to content

Commit fa5a1e9

Browse files
author
Longsheng Du
committed
add onednn graph dialect
1 parent 294c4bd commit fa5a1e9

File tree

13 files changed

+705
-3
lines changed

13 files changed

+705
-3
lines changed

include/gc-dialects/OnednnGraph/OnednnGraphDialect.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
#define GC_DIALECTS_ONEDNNGRAPHDIALECT_H
1111

1212
#include "mlir/IR/Dialect.h"
13+
#include "mlir/IR/OpDefinition.h"
14+
#include "mlir/IR/OpImplementation.h"
1315

16+
#define GET_OP_CLASSES
1417
#include "gc-dialects/OnednnGraph/OnednnGraphOpsDialect.h.inc"
1518

1619
#endif // GC_DIALECTS_ONEDNNGRAPHDIALECT_H

include/gc-dialects/OnednnGraph/OnednnGraphDialect.td

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
include "mlir/IR/OpBase.td"
1313

1414
//===----------------------------------------------------------------------===//
15-
// OneDNNGraph dialect definition.
15+
// OnednnGraph dialect definition.
1616
//===----------------------------------------------------------------------===//
1717

1818
def OnednnGraphDialect : Dialect {
@@ -22,8 +22,6 @@ def OnednnGraphDialect : Dialect {
2222
This dialect follows oneDNN Graph Specification.
2323
}];
2424
let cppNamespace = "::mlir::onednn_graph";
25-
26-
let useDefaultTypePrinterParser = 1;
2725
}
2826

2927
#endif // ONEDNNGRAPH_DIALECT

include/gc-dialects/OnednnGraph/OnednnGraphOps.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
#ifndef GC_DIALECTS_ONEDNNGRAPHOPS_H
1010
#define GC_DIALECTS_ONEDNNGRAPHOPS_H
1111

12+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/IR/BuiltinTypes.h"
15+
#include "mlir/IR/Dialect.h"
1216
#include "mlir/IR/OpDefinition.h"
17+
#include "mlir/Interfaces/InferTypeOpInterface.h"
18+
#include "mlir/Interfaces/SideEffectInterfaces.h"
1319

1420
#define GET_OP_CLASSES
1521
#include "gc-dialects/OnednnGraph/OnednnGraphOps.h.inc"

include/gc-dialects/OnednnGraph/OnednnGraphOps.td

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,102 @@
99
#ifndef ONEDNNGRAPH_OPS
1010
#define ONEDNNGRAPH_OPS
1111

12+
include "mlir/Interfaces/InferTypeOpInterface.td"
13+
include "mlir/Interfaces/SideEffectInterfaces.td"
14+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
15+
include "mlir/IR/AttrTypeBase.td"
16+
include "mlir/IR/OpBase.td"
1217
include "OnednnGraphDialect.td"
18+
include "OnednnGraphTypes.td"
19+
20+
//===----------------------------------------------------------------------===//
21+
// Base OnednnGraph operation definition.
22+
//===----------------------------------------------------------------------===//
23+
24+
class OnednnGraph_Op<string mnemonic, list<Trait> traits = []> :
25+
Op<OnednnGraphDialect, mnemonic, traits>;
26+
27+
class OnednnGraph_ElemwiseBinaryOp<string mnemonic, list<Trait> traits = []> :
28+
OnednnGraph_Op<mnemonic, traits # [SameOperandsAndResultElementType, InferTensorType]> {
29+
let arguments = (ins OnednnGraph_LogicalTensor:$input_0,
30+
OnednnGraph_LogicalTensor:$input_1);
31+
let results = (outs OnednnGraph_LogicalTensor:$result);
32+
33+
let assemblyFormat =
34+
"operands attr-dict `:` functional-type(operands, results)";
35+
}
36+
37+
class OnednnGraph_ElemwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
38+
OnednnGraph_Op<mnemonic, traits # [SameOperandsAndResultType]> {
39+
let arguments = (ins OnednnGraph_LogicalTensor:$operand);
40+
let results = (outs OnednnGraph_LogicalTensor:$result);
41+
42+
let assemblyFormat =
43+
"operands attr-dict `:` functional-type(operands, results)";
44+
}
45+
46+
//===----------------------------------------------------------------------===//
47+
// OnednnGraph op definitions
48+
//===----------------------------------------------------------------------===//
49+
50+
def OnednnGraph_MatMulOp :
51+
OnednnGraph_Op<"matmul", [SameOperandsAndResultElementType, InferTensorTypeAdaptor]> {
52+
let summary = "Generalized matrix multiplication";
53+
let description = [{
54+
`https://spec.oneapi.io/onednn-graph/latest/ops/matrix/MatMul_1.html`
55+
}];
56+
57+
let arguments = (ins OnednnGraph_LogicalTensor:$input_a,
58+
OnednnGraph_LogicalTensor:$input_b,
59+
Optional<OnednnGraph_LogicalTensor>:$bias,
60+
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
61+
DefaultValuedAttr<BoolAttr, "false">:$transpose_b);
62+
let results = (outs OnednnGraph_LogicalTensor:$result);
63+
64+
let assemblyFormat =
65+
"operands attr-dict `:` functional-type(operands, results)";
66+
}
67+
68+
def OnednnGraph_ReLUOp : OnednnGraph_ElemwiseUnaryOp<"relu"> {
69+
let summary = "element-wise relu";
70+
let description = [{
71+
`https://spec.oneapi.io/onednn-graph/latest/ops/activation/ReLU_1.html`
72+
}];
73+
}
74+
75+
// def OnednnGraph_ExpOp : OnednnGraph_ElemwiseUnaryOp<"exp"> {
76+
// let summary = "element-wise exp";
77+
// let description = [{
78+
// `https://spec.oneapi.io/onednn-graph/latest/ops/activation/Exp_1.html`
79+
// }];
80+
// }
81+
82+
def OnednnGraph_AddOp : OnednnGraph_ElemwiseBinaryOp<"add", [Commutative]> {
83+
let summary = "element-wise addition with multi-directional broadcast";
84+
let description = [{
85+
`https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Add_1.html`
86+
}];
87+
}
88+
89+
// def OnednnGraph_SubOp : OnednnGraph_ElemwiseBinaryOp<"subtract"> {
90+
// let summary = "element-wise subtraction with multi-directional broadcast";
91+
// let description = [{
92+
// `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Subtract_1.html`
93+
// }];
94+
// }
95+
96+
// def OnednnGraph_MulOp : OnednnGraph_ElemwiseBinaryOp<"multiply", [Commutative]> {
97+
// let summary = "element-wise multiplication with multi-directional broadcast";
98+
// let description = [{
99+
// `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Multiply_1.html`
100+
// }];
101+
// }
102+
103+
// def OnednnGraph_DivOp : OnednnGraph_ElemwiseBinaryOp<"divide"> {
104+
// let summary = "element-wise division with multi-directional broadcast";
105+
// let description = [{
106+
// `https://spec.oneapi.io/onednn-graph/latest/ops/arithmetic/Divide_1.html`
107+
// }];
108+
// }
13109

14110
#endif // ONEDNNGRAPH_OPS
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
//===- OnednnGraphTypes.h - OneDNN input dialect types ----------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef ONEDNNGRAPH_ONEDNNGRAPHTYPES_H
10+
#define ONEDNNGRAPH_ONEDNNGRAPHTYPES_H
11+
12+
#include "mlir/IR/BuiltinTypes.h"
13+
14+
#define GET_TYPEDEF_CLASSES
15+
#include "gc-dialects/OnednnGraph/OnednnGraphOpsTypes.h.inc"
16+
17+
#endif // ONEDNNGRAPH_ONEDNNGRAPHTYPES_H
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===- OnednnGraphTypes.h - OneDNN input dialect types -----*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef ONEDNNGRAPH_TYPES
10+
#define ONEDNNGRAPH_TYPES
11+
12+
include "mlir/IR/BuiltinTypes.td"
13+
include "mlir/IR/AttrTypeBase.td"
14+
include "OnednnGraphDialect.td"
15+
16+
//===----------------------------------------------------------------------===//
17+
// OnednnGraph type definitions
18+
//===----------------------------------------------------------------------===//
19+
20+
def OnednnGraph_DataType : AnyTypeOf<[
21+
F16,
22+
BF16,
23+
F32,
24+
SI<32>,
25+
SI<8>,
26+
UI<8>]>;
27+
28+
def OnednnGraph_FloatDataType : AnyTypeOf<[
29+
F32,
30+
BF16,
31+
F16]>;
32+
33+
def OnednnGraph_LogicalTensor : TensorOf<[OnednnGraph_DataType]>;
34+
def OnednnGraph_FloatLogicalTensor : TensorOf<[OnednnGraph_FloatDataType]>;
35+
36+
#endif // ONEDNNGRAPH_TYPES

include/gc-dialects/Passes.td

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,18 @@ def TileLinalgNamed : Pass<"tile-named-linalg", "func::FuncOp"> {
1717
["linalg::LinalgDialect", "scf::SCFDialect", "tensor::TensorDialect"];
1818
}
1919

20+
def ConvertOnednnGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
21+
let summary = "Lower the operations from the oneDNN Graph dialect into Linalg";
22+
let description = [{
23+
Lowers the `onednn_graph` ops to `linalg` ops.
24+
}];
25+
let dependentDialects = [
26+
"func::FuncDialect",
27+
"math::MathDialect",
28+
"arith::ArithDialect",
29+
"tensor::TensorDialect",
30+
"linalg::LinalgDialect"
31+
];
32+
}
33+
2034
#endif // GC_DIALECT_GC_PASSES

lib/gc-dialects/OnednnGraph/OnednnGraphDialect.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,34 @@
88

99
#include "gc-dialects/OnednnGraph/OnednnGraphDialect.h"
1010
#include "gc-dialects/OnednnGraph/OnednnGraphOps.h"
11+
#include "gc-dialects/OnednnGraph/OnednnGraphTypes.h"
12+
13+
#include "mlir/Dialect/Quant/QuantOps.h"
14+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
15+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16+
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
17+
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
18+
#include "mlir/Dialect/Utils/IndexingUtils.h"
19+
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/IR/DialectImplementation.h"
21+
#include "mlir/IR/Matchers.h"
22+
#include "mlir/IR/PatternMatch.h"
23+
#include "mlir/IR/TypeUtilities.h"
24+
#include "mlir/Interfaces/InferTypeOpInterface.h"
25+
#include "mlir/Transforms/InliningUtils.h"
26+
#include "llvm/ADT/APFloat.h"
27+
#include "llvm/ADT/DenseMap.h"
28+
#include "llvm/ADT/TypeSwitch.h"
1129

1230
using namespace mlir;
1331
using namespace mlir::onednn_graph;
1432

33+
#include "gc-dialects/OnednnGraph/OnednnGraphOpsDialect.cpp.inc"
34+
35+
//===----------------------------------------------------------------------===//
36+
// OnednnGraph dialect.
37+
//===----------------------------------------------------------------------===//
38+
1539
void OnednnGraphDialect::initialize() {
1640
addOperations<
1741
#define GET_OP_LIST

0 commit comments

Comments
 (0)