Skip to content

Commit bee0ddb

Browse files
[Dialect] Introduce microkernel dialect (#114)
* add microkernel dialect * remove Utils borrowed from TPP
1 parent e2e7149 commit bee0ddb

15 files changed

+539
-10
lines changed

include/gc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
add_subdirectory(Dialect)
2-
add_subdirectory(Transforms)
2+
add_subdirectory(Transforms)

include/gc/Dialect/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
add_subdirectory(CPURuntime)
22
add_subdirectory(OneDNNGraph)
33
add_subdirectory(Microkernel)
4-
add_subdirectory(Linalgx)
4+
add_subdirectory(Linalgx)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
set(LLVM_TARGET_DEFINITIONS MicrokernelEnum.td)
2+
mlir_tablegen(MicrokernelEnum.h.inc -gen-enum-decls)
3+
mlir_tablegen(MicrokernelEnum.cpp.inc -gen-enum-defs)
4+
add_public_tablegen_target(MLIRMicrokernelAttrDefIncGen)
5+
16
add_mlir_dialect(MicrokernelOps microkernel)
27
add_mlir_doc(MicrokernelOps MicrokernelOps gc/Dialect/Microkernel/ -gen-op-doc)
38
add_mlir_doc(MicrokernelDialect MicrokernelDialect gc/Dialect/Microkernel/ -gen-dialect-doc)

include/gc/Dialect/Microkernel/MicrokernelDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define GC_DIALECTS_MICROKERNELDIALECT_H
1111

1212
#include "mlir/IR/Dialect.h"
13+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
1314

1415
#include "gc/Dialect/Microkernel/MicrokernelOpsDialect.h.inc"
1516

include/gc/Dialect/Microkernel/MicrokernelDialect.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ include "mlir/IR/OpBase.td"
1515
// Microkernel dialect definition.
1616
//===----------------------------------------------------------------------===//
1717

18-
def MicrokernelDialect : Dialect {
18+
def Microkernel_Dialect : Dialect {
1919
let name = "microkernel";
2020
let summary = "A dialect for microkernel abstraction.";
2121
let description = [{
22-
The dialect wraps the BRGEMM API to set up the HW context etc.
22+
This dialect contains wrappers for microkernel primitives like BRGEMM.
2323
}];
2424
let cppNamespace = "::mlir::microkernel";
25-
26-
let useDefaultTypePrinterParser = 1;
2725
}
2826

27+
//===----------------------------------------------------------------------===//
28+
// Base microkernel operation definition.
29+
//===----------------------------------------------------------------------===//
30+
31+
class Microkernel_Op<string mnemonic, list<Trait> traits = []> :
32+
Op<Microkernel_Dialect, mnemonic, traits>;
33+
2934
#endif // MICROKERNEL_DIALECT
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
//===- MicrokernelEnum.h - microkernel dialect enums ------------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef GC_DIALECTS_MICROKERNELENUM_H
10+
#define GC_DIALECTS_MICROKERNELENUM_H
11+
12+
#include "mlir/IR/Attributes.h"
13+
#include "mlir/IR/DialectImplementation.h"
14+
15+
#define GET_ATTRDEF_CLASSES
16+
#include "gc/Dialect/Microkernel/MicrokernelEnum.h.inc"
17+
18+
#endif // GC_DIALECTS_MICROKERNELENUM_H
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- MicrokernelEnum.td - microkernel dialect enum -------*- tablegen -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MICROKERNEL_ENUM
10+
#define MICROKERNEL_ENUM
11+
12+
include "mlir/IR/EnumAttr.td"
13+
include "gc/Dialect/Microkernel/MicrokernelDialect.td"
14+
15+
def Microkernel_BrgemmFlags : I64EnumAttr<
16+
"BrgemmFlags", "Flags for indicating optional behaviours of Brgemm",
17+
[
18+
I64EnumAttrCase<"NONE", 0, "none">,
19+
I64EnumAttrCase<"BETA_0", 1, "beta_0">,
20+
I64EnumAttrCase<"STRIDE", 2, "stride">,
21+
I64EnumAttrCase<"LIST", 4, "list">
22+
]> {
23+
let cppNamespace = "::mlir::microkernel";
24+
}
25+
26+
#endif // MICROKERNEL_ENUM

include/gc/Dialect/Microkernel/MicrokernelOps.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@
99
#ifndef GC_DIALECTS_MICROKERNELOPS_H
1010
#define GC_DIALECTS_MICROKERNELOPS_H
1111

12+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
13+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
14+
#include "mlir/Dialect/SCF/IR/SCF.h"
15+
#include "mlir/IR/BuiltinTypes.h"
16+
#include "mlir/IR/Dialect.h"
1217
#include "mlir/IR/OpDefinition.h"
18+
#include "mlir/Interfaces/SideEffectInterfaces.h"
19+
20+
#include "gc/Dialect/Microkernel/MicrokernelDialect.h"
21+
#include "gc/Dialect/Microkernel/MicrokernelEnum.h"
1322

1423
#define GET_OP_CLASSES
1524
#include "gc/Dialect/Microkernel/MicrokernelOps.h.inc"

include/gc/Dialect/Microkernel/MicrokernelOps.td

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,107 @@
1010
#define MICROKERNEL_OPS
1111

1212
include "MicrokernelDialect.td"
13+
include "gc/Dialect/Microkernel/MicrokernelEnum.td"
14+
include "mlir/Interfaces/SideEffectInterfaces.td"
1315

14-
#endif // MICROKERNEL_OPS
16+
class StaticMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
17+
Type<And<[MemRefOf<allowedTypes>.predicate,
18+
HasAnyRankOfPred<ranks>, HasStaticShapePred]>,
19+
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " static " #
20+
MemRefOf<allowedTypes>.summary, "::mlir::MemRefType">;
21+
22+
def Microkernel_BrgemmDispatchOp : Microkernel_Op<"brgemm.dispatch", [Pure]> {
23+
let summary = "JIT the brgemm microkernel given the parameters";
24+
let description = [{
25+
The operation has the following arguments: 1) m, n, k, lda, ldb, ldc, stride_a and stride_b.
26+
Inputs is a dense attribute of I64 elements. 2) flags carry information on
27+
the different flags that can be used for brgemm like whether beta == 0 or strided batch. For
28+
more details, see: `Microkernel_BrgemmFlags`. 3) data_types of operand A & B.
29+
Outpus is the id of JITed kernel.
30+
}];
31+
32+
let arguments = (ins
33+
ConfinedAttr<DenseI64ArrayAttr,
34+
[DenseArrayNonNegative<DenseI64ArrayAttr>]>:$inputs,
35+
TypedArrayAttrBase<Microkernel_BrgemmFlags, "brgemm flags">:$flags,
36+
TypedArrayAttrBase<TypeAttr, "brgemm dtypes">:$data_type);
37+
38+
let results = (outs I64:$results);
39+
let hasCustomAssemblyFormat = 1;
40+
let hasVerifier = 1;
41+
}
42+
43+
def Microkernel_BrgemmPrologueOp : Microkernel_Op<"brgemm.prologue"> {
44+
let summary = "Prologue before executing the JITed brgemm "
45+
"microkernel, and the context is considered core-level";
46+
let description = [{
47+
The operation has the following arguments: Input is the id of JITed kernel.
48+
There is no output.
49+
}];
50+
51+
let arguments = (ins I64:$inputs);
52+
53+
let assemblyFormat = [{
54+
`(` $inputs `)`
55+
attr-dict `:` functional-type($inputs, results)
56+
}];
57+
}
58+
59+
def Microkernel_BrgemmEpilogueOp : Microkernel_Op<"brgemm.epilogue"> {
60+
let summary = "Epilogue after executing the JITed brgemm microkernel";
61+
let description = [{
62+
The operation has the following arguments: Input is the id of JITed kernel.
63+
There is no output.
64+
}];
65+
66+
let arguments = (ins I64:$inputs);
67+
68+
let assemblyFormat = [{
69+
`(` $inputs `)`
70+
attr-dict `:` functional-type($inputs, results)
71+
}];
72+
}
73+
74+
/* A generic input type of Microkernel_BrgemmOp, allowing for `BrgemmMemRef` and I64.
75+
* The `BrgemmMemRef` should be a static MemRef, and for each operand its shape should be:
76+
* Operand A: StaticMemRefRankOf<[F32, BF16, SI8, UI8], [3]>;
77+
* Operand B (none-VNNI): StaticMemRefRankOf<[F32], [3]>;
78+
* Operand B (VNNI): StaticMemRefRankOf<[BF16, SI8, UI8], [4]>;
79+
* Operand C: StaticMemRefRankOf<[F32, SI32], [2]>;
80+
*/
81+
def BrgemmMemRefOrI64 : AnyTypeOf<[StaticMemRefRankOf<[F32, BF16, SI32, SI8, UI8], [2, 3, 4]>, I64]>;
82+
83+
def Microkernel_BrgemmOp : Microkernel_Op<"brgemm"> {
84+
let summary = "execute the JITed brgemm kernel.";
85+
let description = [{
86+
The operation has the following arguments:
87+
1) For stride mode, id of JITed kernel, MemRef of operand A/B/C, and the batch size;
88+
2) For addr mode, plus the length of addr list at the end.
89+
There is no output.
90+
}];
91+
92+
let arguments = (ins Variadic<BrgemmMemRefOrI64>:$inputs);
93+
94+
let assemblyFormat = [{
95+
`(` $inputs `)`
96+
attr-dict `:` functional-type($inputs, results)
97+
}];
98+
99+
let extraClassDeclaration = [{
100+
Value getDispatch() { return getInputs()[0]; }
101+
102+
Value getOperandA() { return getInputs()[1]; }
103+
104+
Value getOperandB() { return getInputs()[2]; }
105+
106+
Value getOutput() { return getInputs()[3]; }
107+
108+
Value getBatch() { return getInputs()[4]; }
109+
110+
Value getAddrLen() { return getInputs()[5]; }
111+
}];
112+
113+
let hasVerifier = 1;
114+
}
115+
116+
#endif // MICROKERNEL_OPS

lib/gc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ include(functions)
77
add_subdirectory(CAPI)
88
add_subdirectory(Dialect)
99
add_subdirectory(Transforms)
10-
add_subdirectory(ExecutionEngine)
10+
add_subdirectory(ExecutionEngine)

lib/gc/Dialect/Microkernel/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
22

33
add_mlir_dialect_library(MLIRMicrokernel
4+
MicrokernelEnum.cpp
45
MicrokernelDialect.cpp
56
MicrokernelOps.cpp
67

@@ -12,5 +13,6 @@ add_mlir_dialect_library(MLIRMicrokernel
1213

1314
LINK_LIBS PUBLIC
1415
${MLIR_LINK_COMPONENTS}
16+
GCUtilsIR
1517
)
16-
set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRMicrokernel)
18+
set_property(GLOBAL APPEND PROPERTY GC_DIALECT_LIBS MLIRMicrokernel)

lib/gc/Dialect/Microkernel/MicrokernelDialect.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "gc/Dialect/Microkernel/MicrokernelDialect.h"
10+
#include "gc/Dialect/Microkernel/MicrokernelEnum.h"
1011
#include "gc/Dialect/Microkernel/MicrokernelOps.h"
1112

1213
using namespace mlir;
1314
using namespace mlir::microkernel;
1415

16+
#include "gc/Dialect/Microkernel/MicrokernelOpsDialect.cpp.inc"
17+
18+
//===----------------------------------------------------------------------===//
19+
// Microkernel dialect.
20+
//===----------------------------------------------------------------------===//
21+
1522
void MicrokernelDialect::initialize() {
1623
addOperations<
1724
#define GET_OP_LIST
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===-- MicrokernelEnum.cpp - microkernel dialect enum ----------*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Dialect/Microkernel/MicrokernelEnum.h"
10+
#include "llvm/ADT/TypeSwitch.h"
11+
12+
using namespace mlir;
13+
using namespace mlir::microkernel;
14+
15+
#include "gc/Dialect/Microkernel/MicrokernelEnum.cpp.inc"

0 commit comments

Comments
 (0)