
[Transform] Introduce microkernel dialect optimization passes #296


Merged: 103 commits, Sep 5, 2024

Commits (103)
8213a9a
add microkernel dialect
May 30, 2024
1c43182
fix licenses
Jul 8, 2024
e522618
fix license check
Jul 8, 2024
88f645a
fix tidy
Jul 8, 2024
8a7ec98
fix lint
Jul 8, 2024
738ba0c
remove Utils borrowed from TPP
Jul 24, 2024
3f57403
fix CMake
Jul 24, 2024
e39ba7e
fix per comments
Jul 26, 2024
4acf417
add dialect lowering pass
May 30, 2024
6a1260a
remove irrelevant
May 30, 2024
1850d60
refine cmake
May 30, 2024
5bc44e4
fix brgemm runtime
Jun 4, 2024
1c69ee6
support linalgx::batch_reduce_matmul_vnni
Jun 5, 2024
2ce6f4c
fix runtime dnnl brgemm correctness
Jun 11, 2024
e0e8b94
fix format
Jun 11, 2024
921b0dc
support pattern with linalg.fill
Jun 14, 2024
6ec1053
move brgemm init_tiles to dispatch time
Jun 17, 2024
f014e73
move mlir tests to right place
Jun 17, 2024
f586efb
use thread_local for scratch buffer
Jun 25, 2024
c4e4bcf
refine memref ptr/offset extraction
Jun 26, 2024
f51ea4c
revert pass change
Jun 26, 2024
6ad33cf
fix op preceding check
Jul 3, 2024
a9a683a
fix utils header
Jul 24, 2024
619f670
accommodate to new utils
Jul 24, 2024
e31a6d3
fix licenses
Jul 24, 2024
f8100e1
update clang-tidy workflow
Jul 24, 2024
ae3e9f8
fix tidy
Jul 24, 2024
c5cbbd3
fix tidy
Jul 24, 2024
43b0c28
fix tidy
Jul 25, 2024
334be08
give tests better names
Jul 30, 2024
f548a18
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 1, 2024
bb21eef
fix clang-format
Aug 1, 2024
f96bcaf
minor fixes as per reviews
Aug 6, 2024
9f59f60
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 6, 2024
1075918
merge main CMake change
Aug 6, 2024
181cbf0
minor fixes & change GC_ENABLE_DNNL to GC_ENABLE_DNNL_API
Aug 7, 2024
9e69362
remove comments in cmake
Aug 8, 2024
1f00bae
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 8, 2024
9163da3
add runtime brgemm entry point
Aug 8, 2024
29525c9
change cmake option name GC_ENABLE_TEST_DNNL to GC_ENABLE_TEST_DNNL_API
Aug 8, 2024
cb9ac09
use smart ptr to manage palette buffer
Aug 8, 2024
51caf8d
fix clang format
Aug 9, 2024
44937e4
Remove pass ConvertMicrokernelToDnnlFunc
Aug 9, 2024
c01ad17
remove pass ConvertLinalgToMicrokernel
Aug 9, 2024
59a6366
add cmake error message
Aug 9, 2024
18ff855
use rw lock
Aug 9, 2024
d7e1509
fix naive lock
Aug 9, 2024
8e631c7
add ut for brgemm runtime
Aug 14, 2024
7fafd3e
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 14, 2024
784760d
fix clang-tidy
Aug 14, 2024
9bfbf7b
fix clang-tidy
Aug 14, 2024
290d2cb
Merge branch 'main' into haixin/microkernel_dialect_lowering
Aug 15, 2024
0c97fa1
Merge branch 'main' into haixin/microkernel_dialect_lowering
Aug 19, 2024
7d3cd4b
Revert "remove pass ConvertLinalgToMicrokernel"
Aug 9, 2024
5165ba8
fix as per reviews
Aug 20, 2024
3f6fe06
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 20, 2024
59d63c7
remove unnecessary header
Aug 20, 2024
87900a9
fix per reviews
Aug 21, 2024
7f94552
minor fix
Aug 21, 2024
6bbd864
Revert "Remove pass ConvertMicrokernelToDnnlFunc"
Aug 21, 2024
2ac30c6
add basic BrgemmOnTensorOp
Jul 23, 2024
0bdce9e
add Bufferize support
Jul 23, 2024
7d51603
add ExpandMicrokernel pass
Jul 24, 2024
bc95df3
add lowering from linalgOp to newly added brgemmOp
Jul 25, 2024
039f24d
fix header
Jul 25, 2024
ccd31cc
fix compile issue
Jul 26, 2024
15b09c5
minor fix on test cases
Jul 26, 2024
faa74db
add test & bug fixes
Jul 29, 2024
9463006
minor fixes & add tests
Jul 30, 2024
a5b99a4
fix bufferizableOpInterface
Aug 14, 2024
75e926b
fix BrgemmOp asm & add bufferization tests
Aug 14, 2024
6d445cb
fix merge issue
Aug 21, 2024
e285841
fix format
Aug 21, 2024
4e2c053
fix clang
Aug 21, 2024
c5d935b
fix cpu-runner test
Aug 22, 2024
7056588
fix per review
Aug 23, 2024
ef88373
fix per comments
Aug 26, 2024
c86c71e
fix clang-tidy
Aug 26, 2024
5831fef
replace some with real types
Aug 26, 2024
c00fcf4
add optimization pass
May 30, 2024
2cd8df0
move test mlir to right place
Jun 17, 2024
b3e60af
[To be tested] add pass
Jun 18, 2024
fa44d9a
add test & bugfix for new pass
Jun 19, 2024
994c8b0
fix global_ctors lookup
Jun 19, 2024
b13b5cd
refine test cases
Jun 26, 2024
af45c6c
fix util headers
Jul 24, 2024
b42b5b0
fix license and tidy
Jul 24, 2024
f195953
fix clang-tidy
Jul 25, 2024
dc119ed
Revert "change clang-tidy-version"
Jul 25, 2024
e482735
Revert "Revert "change clang-tidy-version""
Jul 25, 2024
c606f6d
refactor per reviews
Jul 30, 2024
24e8681
fix mlir test
Aug 22, 2024
0f1fa86
improve cpu-runner test
Aug 22, 2024
90efaf6
refine mlir test
Aug 22, 2024
a2e132e
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 28, 2024
e4b96f6
code & test refinements
Aug 30, 2024
9a0e020
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 30, 2024
eeb8e1f
add microkernel passes to pipeline
Aug 30, 2024
4610480
fix per review
Sep 3, 2024
9254dbb
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Sep 3, 2024
a853cd8
ignore upstream linalg op with invalid input
Sep 4, 2024
fedd427
add TODO comments
Sep 4, 2024
eb94d05
fix correctness check
Sep 4, 2024
31 changes: 31 additions & 0 deletions include/gc/Transforms/Microkernel/MicrokernelPasses.td
@@ -76,4 +76,35 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml
"microkernel::MicrokernelDialect"];
}

def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> {
let summary = "Early dispatch microkernel during compile time";
let description = [{
Early dispatch microkernel during compile time.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"LLVM::LLVMDialect",
"microkernel::MicrokernelDialect"];
}

def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> {
let summary = "Find and merge identical microkernel context operations in branches into one";
let description = [{
Find and merge identical microkernel context operations in branches into one.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect"];
}

def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> {
let summary = "Hoist invariant microkernel code to avoid redundant execution";
let description = [{
Hoist invariant microkernel code to avoid redundant execution.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"LLVM::LLVMDialect",
"microkernel::MicrokernelDialect"];
}

#endif // GC_DIALECT_MICROKERNELPASSES
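
The three pass definitions above only declare the passes; as a rough sketch of how they would typically be composed, the snippet below adds them to an MLIR pass pipeline in the order given. The factory function names and the namespace are assumptions made for illustration (TableGen-generated pass declarations usually follow this pattern), not code taken from this PR.

```cpp
// Hypothetical pipeline wiring; the create*() factory names are assumed.
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

namespace example {

void buildMicrokernelOptPipeline(mlir::PassManager &pm) {
  // Resolve brgemm kernel dispatch at compile time where possible.
  pm.addPass(mlir::microkernel::createEarlyDispatchMicrokernel());
  // Merge identical microkernel context operations found in branches.
  pm.addPass(mlir::microkernel::createMergeBranchMicrokernelContext());
  // Hoist invariant microkernel setup code out of loops.
  pm.addPass(mlir::microkernel::createMicrokernelInvariantCodeMotion());
}

} // namespace example
```
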
95 changes: 53 additions & 42 deletions lib/gc/Dialect/Microkernel/MicrokernelOps.cpp
@@ -213,6 +213,42 @@ static LogicalResult verifyBrgemmFlags(ArrayAttr flags, Operation *op,
return success();
}

static bool isTypeSupported(Type outType, Type operandAType,
Type operandBType) {
if (!outType.isF32() && !outType.isSignedInteger(32))
return false;

if (outType.isF32()) {
if (!(operandAType.isF32() && operandBType.isF32()) &&
!(operandAType.isBF16() && operandBType.isBF16()))
return false;
}
if (outType.isSignedInteger(32)) {
if (!(operandAType.isSignedInteger(8) ||
operandAType.isUnsignedInteger(8)) &&
(operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8)))
return false;
}
return true;
}

// TODO(haixin): could use compiler-wide VNNI utils?
static bool isInVnniLayout(ShapedType type) {
if (!type.getElementType().isBF16() &&
!type.getElementType().isSignedInteger(8) &&
!type.getElementType().isUnsignedInteger(8))
return false;

auto blockingFactor = 0;
if (type.getElementType().isBF16())
blockingFactor = 2;
else if (type.getElementType().isSignedInteger(8) ||
type.getElementType().isUnsignedInteger(8))
blockingFactor = 4;

return type.getShape().back() == blockingFactor;
}

/////////////////////////////////////////////////////
// Start of BrgemmOp

@@ -308,9 +344,8 @@ static inline ArrayRef<int64_t> getShapedValueShape(Value val) {
assert((llvm::isa<TensorType>(val.getType()) ||
llvm::isa<MemRefType>(val.getType())) &&
"Expecting shaped value");
if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType())) {
if (auto tensorTy = dyn_cast_or_null<TensorType>(val.getType()))
return tensorTy.getShape();
}
auto memrefTy = dyn_cast_or_null<MemRefType>(val.getType());
return memrefTy.getShape();
}
@@ -331,15 +366,27 @@ LogicalResult BrgemmOp::verify() {
return op.emitOpError()
<< "expect inputs and its related info to be size 2\n";

auto elemTypeA = getElementTypeOrSelf(ins[0]);
auto elemTypeB = getElementTypeOrSelf(ins[1]);
auto elemTypeC = getElementTypeOrSelf(out);
if (!isTypeSupported(elemTypeC, elemTypeA, elemTypeB))
return op.emitOpError() << "unsupported input matrix types\n";

ArrayRef<int64_t> dimA = getShapedValueShape(ins[0]);
ArrayRef<int64_t> dimB = getShapedValueShape(ins[1]);
ArrayRef<int64_t> dimC = getShapedValueShape(out);
if (dimA.size() != 3)
return op.emitOpError() << "expect input A to be 3D\n";
if (dimB.size() != 3 && dimB.size() != 4)
return op.emitOpError() << "expect input B to be 3D or 4D\n";
if (dimB.size() == 4 && (dimB[3] != 2 && dimB[3] != 4))
return op.emitOpError() << "expect input B vnni step to be 2 or 4\n";
if (!elemTypeB.isF32()) {
if (dimB.size() != 4 ||
!isInVnniLayout(dyn_cast<ShapedType>(ins[1].getType())))
return op.emitOpError()
<< "expect a 4d VNNI input B for non-F32 operand: " << ins[1];
} else {
if (dimB.size() != 3)
return op.emitOpError()
<< "expect a 3d input B for F32 operand: " << ins[1];
}
if (dimC.size() != 2)
return op.emitOpError() << "expect input C to be 2D\n";
for (auto dim : batchDims)
@@ -558,42 +605,6 @@ LogicalResult BrgemmDispatchOp::verify() {
/////////////////////////////////////////////////////
// Start of BrgemmExecuteOp

// TODO(haixin): could use compiler-wide VNNI utils?
static bool isInVnniLayout(MemRefType memref) {
if (!memref.getElementType().isBF16() &&
!memref.getElementType().isSignedInteger(8) &&
!memref.getElementType().isUnsignedInteger(8))
return false;

auto blockingFactor = 0;
if (memref.getElementType().isBF16())
blockingFactor = 2;
else if (memref.getElementType().isSignedInteger(8) ||
memref.getElementType().isUnsignedInteger(8))
blockingFactor = 4;

return memref.getShape().back() == blockingFactor;
}

static bool isTypeSupported(Type outType, Type operandAType,
Type operandBType) {
if (!outType.isF32() && !outType.isSignedInteger(32))
return false;

if (outType.isF32()) {
if (!(operandAType.isF32() && operandBType.isF32()) &&
!(operandAType.isBF16() && operandBType.isBF16()))
return false;
}
if (outType.isSignedInteger(32)) {
if (!(operandAType.isSignedInteger(8) ||
operandAType.isUnsignedInteger(8)) &&
(operandBType.isSignedInteger(8) || operandBType.isUnsignedInteger(8)))
return false;
}
return true;
}

LogicalResult BrgemmExecuteOp::verify() {
BrgemmExecuteOp &brgemmOp = *this;

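
To make the intent of the isTypeSupported and isInVnniLayout helpers (and the stricter BrgemmOp::verify rules that use them) easier to follow, here is a standalone sketch of the same rules. It is illustration only: element types are modeled as an enum rather than mlir::Type, and it restates the intended combinations rather than copying the PR code verbatim.

```cpp
// Standalone restatement (not the PR code) of the type and VNNI-layout rules.
#include <cstdint>
#include <vector>

enum class Elem { F32, BF16, S8, U8, S32 };

// Intended combinations: f32 accumulation takes f32*f32 or bf16*bf16;
// i32 accumulation takes s8/u8 * s8/u8.
static bool typeComboSupported(Elem out, Elem a, Elem b) {
  if (out == Elem::F32)
    return (a == Elem::F32 && b == Elem::F32) ||
           (a == Elem::BF16 && b == Elem::BF16);
  if (out == Elem::S32)
    return (a == Elem::S8 || a == Elem::U8) &&
           (b == Elem::S8 || b == Elem::U8);
  return false;
}

// VNNI rule: the innermost dimension must equal the blocking factor,
// which is 2 for bf16 and 4 for 8-bit integers; f32 is never VNNI-packed.
static bool inVnniLayout(Elem elem, const std::vector<int64_t> &shape) {
  int64_t blocking = 0;
  if (elem == Elem::BF16)
    blocking = 2;
  else if (elem == Elem::S8 || elem == Elem::U8)
    blocking = 4;
  else
    return false;
  return !shape.empty() && shape.back() == blocking;
}

int main() {
  bool ok = typeComboSupported(Elem::F32, Elem::BF16, Elem::BF16) &&
            inVnniLayout(Elem::BF16, {4, 16, 32, 2}); // bf16 VNNI input B
  return ok ? 0 : 1;
}
```
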
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_subdirectory(Utils)
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
MLIRIR
MLIRSupport
MLIRMicrokernelTransforms
MLIRBufferizationToMemRef
MLIRBufferizationPipelines)

5 changes: 4 additions & 1 deletion lib/gc/Transforms/Microkernel/CMakeLists.txt
@@ -1,11 +1,14 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRMicrokernel)

include(onednn)

gc_add_mlir_dialect_library(MLIRMicrokernelTransforms
ConvertLinalgToMicrokernel.cpp
ExpandMicrokernel.cpp
ConvertMicrokernelToDnnlFunc.cpp
EarlyDispatchMicrokernel.cpp
MicrokernelInvariantCodeMotion.cpp
MergeBranchMicrokernelContext.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/
20 changes: 17 additions & 3 deletions lib/gc/Transforms/Microkernel/ConvertLinalgToMicrokernel.cpp
@@ -157,6 +157,23 @@ static FailureOr<BrgemmDims> inferBrgemmDims(linalg::LinalgOp linalgOp) {
else
return failure();

OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
OpOperand *operandB = linalgOp.getDpsInputOperands()[1];
Type operandBElemType = getElementTypeOrSelf(operandB->get());
if (operandBElemType.isF32()) {
if (kAffinePos.size() == 2) {
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input "
"B, should be non-VNNI\n");
return failure();
}
} else {
if (kAffinePos.size() == 1) {
LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions for input "
"B, should be VNNI\n");
return failure();
}
}

LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] Candidate dims: "
<< "\n");
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] m pos in affine: " << mAffinePos
@@ -169,9 +186,6 @@ static FailureOr<BrgemmDims> inferBrgemmDims(linalg::LinalgOp linalgOp) {
LLVM_DEBUG(llvm::dbgs() << "[inferBrgemmDims] batch pos in affine: "
<< batchAffinePos << "\n");

OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
OpOperand *operandB = linalgOp.getDpsInputOperands()[1];

BrgemmDims brgemmDims;

#define CHECK_GET_POS_IN_DOMAIN(dim, dimPos, operand) \
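
The check added to inferBrgemmDims above ties the rank of input B to its element type: f32 operands are expected in a plain non-VNNI layout with a single K dimension, while bf16 and 8-bit operands are expected in VNNI form with the K dimension split in two. A minimal standalone restatement of that rule, using an enum in place of MLIR types, is sketched below.

```cpp
// Illustrative sketch only, not the pass itself: an f32 input B should expose
// one reduction (K) dimension ([batch, K, N]), while bf16/s8/u8 inputs should
// expose two because of VNNI packing ([batch, K/vnni, N, vnni]).
#include <cstddef>
#include <iostream>

enum class ElemType { F32, BF16, S8, U8 };

static bool hasExpectedKDims(ElemType elemB, std::size_t numKDims) {
  if (elemB == ElemType::F32)
    return numKDims == 1; // non-VNNI B: a single K dimension
  return numKDims == 2;   // VNNI B: K is split into (K/vnni, vnni)
}

int main() {
  std::cout << hasExpectedKDims(ElemType::F32, 1) << "\n";  // 1: accepted
  std::cout << hasExpectedKDims(ElemType::BF16, 1) << "\n"; // 0: rejected, needs VNNI
  return 0;
}
```
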