[Transform] Introduce microkernel dialect optimization passes (#296)
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 98 commits (103 commits total)
8213a9a add microkernel dialect
1c43182 fix licenses
e522618 fix license check
88f645a fix tidy
8a7ec98 fix lint
738ba0c remove Utils borrowed from TPP
3f57403 fix CMake
e39ba7e fix per comments
4acf417 add dialect lowering pass
6a1260a remove irrelavant
1850d60 refine cmake
5bc44e4 fix brgemm runtime
1c69ee6 support linalgx::batch_reduce_matmul_vnni
2ce6f4c fix runtime dnnl brgemm correctness
e0e8b94 fix format
921b0dc support pattern with linalg.fill
6ec1053 move brgemm init_tiles to dispatch time
f014e73 move mlir tests to right place
f586efb use thread_local for scratch buffer
c4e4bcf refine memref ptr/offset extraction
f51ea4c revert pass change
6ad33cf fix op preceding check
a9a683a fix utils header
619f670 accommodate to new utils
e31a6d3 fix licenses
f8100e1 update clang-tidy workflow
ae3e9f8 fix tidy
c5cbbd3 fix tidy
43b0c28 fix tidy
334be08 give teste better names
f548a18 Merge branch 'main' of https://github.com/intel/graph-compiler into h…
bb21eef fix clang-format
f96bcaf minor fixes as per reviews
9f59f60 Merge branch 'main' of https://github.com/intel/graph-compiler into h…
1075918 merge main CMake change
181cbf0 minor fixes & change GC_ENABLE_DNNL to GC_ENABLE_DNNL_API
9e69362 remove comments in cmake
1f00bae Merge branch 'main' of https://github.com/intel/graph-compiler into h…
9163da3 add runtime brgemm entry point
29525c9 change cmake option name GC_ENABLE_TEST_DNNL to GC_ENABLE_TEST_DNNL_API
cb9ac09 use smart ptr to manage palette buffer
51caf8d fix clang format
44937e4 Remove pass ConvertMicrokernelToDnnlFunc
c01ad17 remove pass ConvertLinalgToMicrokernel
59a6366 add cmake error message
18ff855 use rw lock
d7e1509 fix naive lock
8e631c7 add ut for brgemm runtime
7fafd3e Merge branch 'main' of https://github.com/intel/graph-compiler into h…
784760d fix clang-tidy
9bfbf7b fix clang-tidy
290d2cb Merge branch 'main' into haixin/microkernel_dialect_lowering
0c97fa1 Merge branch 'main' into haixin/microkernel_dialect_lowering
7d3cd4b Revert "remove pass ConvertLinalgToMicrokernel"
5165ba8 fix as per reviews
3f6fe06 Merge branch 'main' of https://github.com/intel/graph-compiler into h…
59d63c7 remove unnecessary header
87900a9 fix per reviews
7f94552 minor fix
6bbd864 Revert "Remove pass ConvertMicrokernelToDnnlFunc"
2ac30c6 add basic BrgemmOnTensorOp
0bdce9e add Bufferize support
7d51603 add ExpandMicrokernel pass
bc95df3 add lowering from linalgOp to newly added brgemmOp
039f24d fix header
ccd31cc fix compile issue
15b09c5 minor fix on test cases
faa74db add test & bug fixes
9463006 minor fixes & add tests
a5b99a4 fix bufferizableOpInterface
75e926b fix BrgemmOp asm & add bufferization tests
6d445cb fix merge issue
e285841 fix format
4e2c053 fix clang
c5d935b fix cpu-runner test
7056588 fix per review
ef88373 fix per comments
c86c71e fix clang-tidy
5831fef replace some with real types
c00fcf4 add optimization pass
2cd8df0 move test mlir to right place
b3e60af [To be tested] add pass
fa44d9a add test & bugfix for new pass
994c8b0 fix global_ctors lookup
b13b5cd refine test cases
af45c6c fix util headers
b42b5b0 fix license and tidy
f195953 fix clang-tidy
dc119ed Revert "change clang-tidy-version"
e482735 Revert "Revert "change clang-tidy-version""
c606f6d refactor per reviews
24e8681 fix mlir test
0f1fa86 improve cpu-runner test
90efaf6 refine mlir test
a2e132e Merge branch 'main' of https://github.com/intel/graph-compiler into h…
e4b96f6 code & test refinements
9a0e020 Merge branch 'main' of https://github.com/intel/graph-compiler into h…
eeb8e1f add microkernel passes to pipeline
4610480 fix per review
9254dbb Merge branch 'main' of https://github.com/intel/graph-compiler into h…
a853cd8 ignore upstream linalg op with invalid input
fedd427 add TODO comments
eb94d05 fix correctness check
lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp (201 additions, 0 deletions)
//===-- EarlyDispatchMicrokernel.cpp - Dispatch before runtime --*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <sstream>

#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h"
#include "gc/Transforms/Microkernel/MicrokernelPasses.h"
#include "gc/Transforms/Utils/ValueUtils.h"
#include "oneapi/dnnl/dnnl_types.h"

namespace mlir::microkernel {
#define GEN_PASS_DEF_EARLYDISPATCHMICROKERNEL
#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc"

#define DEBUG_TYPE "early-dispatch-microkernel"

static FailureOr<std::string>
createGlobalKernelHandleName(RewriterBase &rewriter,
                             microkernel::BrgemmDispatchOp op) {
  // TODO(haixin): Add runtime backend type to global name
  std::stringstream ss;
  ss << "g_dispatched_microkernel_brgemm";

  auto flags = op.getFlagsAttr();
  for (auto flag : flags) {
    auto brgemmFlag = dyn_cast_or_null<microkernel::BrgemmFlagsAttr>(flag);
    if (!brgemmFlag)
      return failure();
    if (brgemmFlag.getValue() == BrgemmFlags::LIST)
      return failure();
    if (brgemmFlag.getValue() == BrgemmFlags::BETA_0)
      ss << "_init";
  }

  // M, N, K, LDA, LDB, LDC, stride_a, stride_b;
  // these appear in the same order as the BrgemmDispatchOp inputs
  ArrayRef<int64_t> inputs = op.getInputsAttr().asArrayRef();
  for (auto input : inputs)
    ss << "_" << input;

  // dtypeA, dtypeB
  auto dtypes = op.getDataType();
  if (dtypes.size() != 2)
    return failure();
  ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]);
  ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[1]);

  return ss.str();
}

// Get or create the global kernel handle (with initializer) identified by
// `kernelName`
static FailureOr<LLVM::GlobalOp>
getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module,
                              const std::string &kernelName,
                              microkernel::BrgemmDispatchOp op) {
  // Create the global at the entry of the module
  LLVM::GlobalOp global = module.lookupSymbol<LLVM::GlobalOp>(kernelName);
  if (global)
    return global;

  auto global_type = op.getResults().getType();
  FlatSymbolRefAttr ctorName =
      SymbolRefAttr::get(module->getContext(), kernelName + "_ctor");
  if (module.lookupSymbol<LLVM::LLVMFuncOp>(ctorName.getAttr()))
    return failure();

  OpBuilder::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPointToStart(module.getBody());
  global = rewriter.create<LLVM::GlobalOp>(
      module.getLoc(), global_type, /*isConstant=*/false,
      LLVM::Linkage::Internal, kernelName, Attribute(),
      /*alignment=*/0);

  // Create a ctor for this global, which needs to be an LLVMFuncOp
  LLVM::LLVMFuncOp ctorFunc = rewriter.create<LLVM::LLVMFuncOp>(
      module.getLoc(), ctorName.getValue(),
      LLVM::LLVMFunctionType::get(global_type, {}, false));

  Location loc = ctorFunc.getLoc();
  Block *entryBlock = ctorFunc.addEntryBlock(rewriter);
  rewriter.setInsertionPointToEnd(entryBlock);

  auto dispatch = op.clone();
  rewriter.insert(dispatch);
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
  rewriter.create<LLVM::StoreOp>(loc, dispatch.getResults(), globalPtr);
  rewriter.create<LLVM::ReturnOp>(loc, dispatch.getResults());

  // Initialize the global with global_ctors, as the initializer of a global
  // does not allow side effects
  rewriter.setInsertionPointToStart(module.getBody());
  LLVM::GlobalCtorsOp global_ctors = nullptr;
  for (auto &op : module->getRegion(0).front()) {
    auto ctorOp = dyn_cast<LLVM::GlobalCtorsOp>(op);
    if (ctorOp) {
      global_ctors = ctorOp;
      break;
    }
  }

  SmallVector<Attribute> ctorRefs;
  SmallVector<Attribute> priorities;
  if (global_ctors) {
    auto ctorRefsAttr = global_ctors.getCtors();
    auto prioritiesAttr = global_ctors.getPriorities();
    for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) {
      ctorRefs.push_back(ctor);
      priorities.push_back(prior);
    }
    LLVM_DEBUG(llvm::dbgs()
               << "After append ctors: " << ctorRefs.size() << "\n");
  }
  ctorRefs.push_back(ctorName);
  // Set the new ctor's priority to lowest
  priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX));
  if (global_ctors) {
    // If there are existing ctors, replace them with the appended list
    LLVM_DEBUG(llvm::dbgs() << "Replace existing ctors\n");
    rewriter.replaceOpWithNewOp<LLVM::GlobalCtorsOp>(
        global_ctors, rewriter.getArrayAttr(ctorRefs),
        rewriter.getArrayAttr(priorities));
  } else {
    LLVM_DEBUG(llvm::dbgs() << "Create new ctor\n");
    rewriter.create<LLVM::GlobalCtorsOp>(module.getLoc(),
                                         rewriter.getArrayAttr(ctorRefs),
                                         rewriter.getArrayAttr(priorities));
  }
  return global;
}

class EarlyDispatchBrgemmRewriter
    : public OpRewritePattern<microkernel::BrgemmDispatchOp> {
public:
  using OpRewritePattern<microkernel::BrgemmDispatchOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op,
                                PatternRewriter &rewriter) const final {
    Location loc = op.getLoc();
    ModuleOp module = op->template getParentOfType<ModuleOp>();
    func::FuncOp func = op->template getParentOfType<func::FuncOp>();

    auto globalKernelName = createGlobalKernelHandleName(rewriter, op);
    if (failed(globalKernelName)) {
      return rewriter.notifyMatchFailure(
          op, "Failed to create global kernel handle name");
    }

    // Get or create the kernel handle global for this name
    auto globalKernel =
        getOrCreateGlobalKernelHandle(rewriter, module, *globalKernelName, op);
    if (failed(globalKernel)) {
      return rewriter.notifyMatchFailure(
          op, "Failed to create global kernel handle");
    }

    // Inject loading of the global value into the start of the function
    auto funcBlock = &func.getBody().front();
    rewriter.setInsertionPointToStart(funcBlock);
    Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, *globalKernel);
    Value globalVal = rewriter.create<LLVM::LoadOp>(
        loc, op.getResults().getType(), globalPtr);
    rewriter.moveOpAfter(op, funcBlock, funcBlock->begin());
    rewriter.replaceOp(op, globalVal);
    return success();
  }
};

class EarlyDispatchMicrokernel
    : public impl::EarlyDispatchMicrokernelBase<EarlyDispatchMicrokernel> {
public:
  using impl::EarlyDispatchMicrokernelBase<
      EarlyDispatchMicrokernel>::EarlyDispatchMicrokernelBase;
  void runOnOperation() final {
    RewritePatternSet patterns(&getContext());
    patterns.add<EarlyDispatchBrgemmRewriter>(&getContext());
    FrozenRewritePatternSet patternSet(std::move(patterns));

    // Ignore newly created ops
    GreedyRewriteConfig config;
    config.strictMode = GreedyRewriteStrictness::ExistingOps;
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), patternSet, config)))
      signalPassFailure();
  }
};

} // namespace mlir::microkernel
Using a string as the key for the global kernel cache might not be a good option when considering post-op fusion or the m_mask stuff in the future.
The string name could be lengthy, but I think it's pretty self-explanatory and it makes the pass independent of the compiler's internal state. Consider a scenario where the IR goes through the following pipeline:
EarlyDispatchMicrokernel -> SomeLoweringPassProducingNewBrgemm -> ConvertLinalgToBrgemm -> EarlyDispatchMicrokernel
If we use the global var name as the cache key, we can easily dedup between the first and second EarlyDispatchMicrokernel. I think it's hard to implement this if we keep the global kernel cache as some internal state of the compiler, especially under test/debug scenarios using mlir-opt, where we might run the passes one by one in different spawns of the process.
For stuff like post-op fusion and masks, we can add the attr into the name as well, with a predefined format, e.g.:
llvm.mlir.global internal @g_mask_1 = xxxx
llvm.mlir.global internal @g_dispatched_microkernel_brgemm_..._mask{g_mask_1}_fusing_relu() ...