Skip to content

[Transform] Introduce microkernel dialect optimization passes #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 103 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 98 commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
8213a9a
add microkernel dialect
May 30, 2024
1c43182
fix licenses
Jul 8, 2024
e522618
fix license check
Jul 8, 2024
88f645a
fix tidy
Jul 8, 2024
8a7ec98
fix lint
Jul 8, 2024
738ba0c
remove Utils borrowed from TPP
Jul 24, 2024
3f57403
fix CMake
Jul 24, 2024
e39ba7e
fix per comments
Jul 26, 2024
4acf417
add dialect lowering pass
May 30, 2024
6a1260a
remove irrelavant
May 30, 2024
1850d60
refine cmake
May 30, 2024
5bc44e4
fix brgemm runtime
Jun 4, 2024
1c69ee6
support linalgx::batch_reduce_matmul_vnni
Jun 5, 2024
2ce6f4c
fix runtime dnnl brgemm correctness
Jun 11, 2024
e0e8b94
fix format
Jun 11, 2024
921b0dc
support pattern with linalg.fill
Jun 14, 2024
6ec1053
move brgemm init_tiles to dispatch time
Jun 17, 2024
f014e73
move mlir tests to right place
Jun 17, 2024
f586efb
use thread_local for scratch buffer
Jun 25, 2024
c4e4bcf
refine memref ptr/offset extraction
Jun 26, 2024
f51ea4c
revert pass change
Jun 26, 2024
6ad33cf
fix op preceding check
Jul 3, 2024
a9a683a
fix utils header
Jul 24, 2024
619f670
accommodate to new utils
Jul 24, 2024
e31a6d3
fix licenses
Jul 24, 2024
f8100e1
update clang-tidy workflow
Jul 24, 2024
ae3e9f8
fix tidy
Jul 24, 2024
c5cbbd3
fix tidy
Jul 24, 2024
43b0c28
fix tidy
Jul 25, 2024
334be08
give teste better names
Jul 30, 2024
f548a18
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 1, 2024
bb21eef
fix clang-format
Aug 1, 2024
f96bcaf
minor fixes as per reviews
Aug 6, 2024
9f59f60
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 6, 2024
1075918
merge main CMake change
Aug 6, 2024
181cbf0
minor fixes & change GC_ENABLE_DNNL to GC_ENABLE_DNNL_API
Aug 7, 2024
9e69362
remove comments in cmake
Aug 8, 2024
1f00bae
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 8, 2024
9163da3
add runtime brgemm entry point
Aug 8, 2024
29525c9
change cmake option name GC_ENABLE_TEST_DNNL to GC_ENABLE_TEST_DNNL_API
Aug 8, 2024
cb9ac09
use smart ptr to manage palette buffer
Aug 8, 2024
51caf8d
fix clang format
Aug 9, 2024
44937e4
Remove pass ConvertMicrokernelToDnnlFunc
Aug 9, 2024
c01ad17
remove pass ConvertLinalgToMicrokernel
Aug 9, 2024
59a6366
add cmake error message
Aug 9, 2024
18ff855
use rw lock
Aug 9, 2024
d7e1509
fix naive lock
Aug 9, 2024
8e631c7
add ut for brgemm runtime
Aug 14, 2024
7fafd3e
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 14, 2024
784760d
fix clang-tidy
Aug 14, 2024
9bfbf7b
fix clang-tidy
Aug 14, 2024
290d2cb
Merge branch 'main' into haixin/microkernel_dialect_lowering
Aug 15, 2024
0c97fa1
Merge branch 'main' into haixin/microkernel_dialect_lowering
Aug 19, 2024
7d3cd4b
Revert "remove pass ConvertLinalgToMicrokernel"
Aug 9, 2024
5165ba8
fix as per reviews
Aug 20, 2024
3f6fe06
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 20, 2024
59d63c7
remove unnecessary header
Aug 20, 2024
87900a9
fix per reviews
Aug 21, 2024
7f94552
minor fix
Aug 21, 2024
6bbd864
Revert "Remove pass ConvertMicrokernelToDnnlFunc"
Aug 21, 2024
2ac30c6
add basic BrgemmOnTensorOp
Jul 23, 2024
0bdce9e
add Bufferize support
Jul 23, 2024
7d51603
add ExpandMicrokernel pass
Jul 24, 2024
bc95df3
add lowering from linalgOp to newly added brgemmOp
Jul 25, 2024
039f24d
fix header
Jul 25, 2024
ccd31cc
fix compile issue
Jul 26, 2024
15b09c5
minor fix on test cases
Jul 26, 2024
faa74db
add test & bug fixes
Jul 29, 2024
9463006
minor fixes & add tests
Jul 30, 2024
a5b99a4
fix bufferizableOpInterface
Aug 14, 2024
75e926b
fix BrgemmOp asm & add bufferization tests
Aug 14, 2024
6d445cb
fix merge issue
Aug 21, 2024
e285841
fix format
Aug 21, 2024
4e2c053
fix clang
Aug 21, 2024
c5d935b
fix cpu-runner test
Aug 22, 2024
7056588
fix per review
Aug 23, 2024
ef88373
fix per comments
Aug 26, 2024
c86c71e
fix clang-tidy
Aug 26, 2024
5831fef
replace some with real types
Aug 26, 2024
c00fcf4
add optimization pass
May 30, 2024
2cd8df0
move test mlir to right place
Jun 17, 2024
b3e60af
[To be tested] add pass
Jun 18, 2024
fa44d9a
add test & bugfix for new pass
Jun 19, 2024
994c8b0
fix global_ctors lookup
Jun 19, 2024
b13b5cd
refine test cases
Jun 26, 2024
af45c6c
fix util headers
Jul 24, 2024
b42b5b0
fix license and tidy
Jul 24, 2024
f195953
fix clang-tidy
Jul 25, 2024
dc119ed
Revert "change clang-tidy-version"
Jul 25, 2024
e482735
Revert "Revert "change clang-tidy-version""
Jul 25, 2024
c606f6d
refactor per reviews
Jul 30, 2024
24e8681
fix mlir test
Aug 22, 2024
0f1fa86
improve cpu-runner test
Aug 22, 2024
90efaf6
refine mlir test
Aug 22, 2024
a2e132e
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 28, 2024
e4b96f6
code & test refinements
Aug 30, 2024
9a0e020
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Aug 30, 2024
eeb8e1f
add microkernel passes to pipeline
Aug 30, 2024
4610480
fix per review
Sep 3, 2024
9254dbb
Merge branch 'main' of https://github.com/intel/graph-compiler into h…
Sep 3, 2024
a853cd8
ignore upstream linalg op with invalid input
Sep 4, 2024
fedd427
add TODO comments
Sep 4, 2024
eb94d05
fix correctness check
Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/gc/Transforms/Microkernel/MicrokernelPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,35 @@ def ConvertMicrokernelToDnnlFunc: Pass<"convert-microkernel-to-dnnl-func", "::ml
"microkernel::MicrokernelDialect"];
}

def EarlyDispatchMicrokernel: Pass<"early-dispatch-microkernel", "::mlir::ModuleOp"> {
let summary = "Early dispatch microkernel during compile time";
let description = [{
Early dispatch microkernel during compile time.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"LLVM::LLVMDialect",
"microkernel::MicrokernelDialect"];
}

def MergeBranchMicrokernelContext: Pass<"merge-branch-microkernel-context", "::mlir::ModuleOp"> {
let summary = "Find and merge identical microkernel context operations in branches into one";
let description = [{
Find and merge identical microkernel context operations in branches into one.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect"];
}

def MicrokernelInvariantCodeMotion: Pass<"microkernel-invariant-code-motion", "::mlir::ModuleOp"> {
let summary = "Hoist invariant microkernel code to avoid redundant execution";
let description = [{
Hoist invariant microkernel code to avoid redundant execution.
}];
let dependentDialects = ["func::FuncDialect",
"memref::MemRefDialect",
"LLVM::LLVMDialect",
"microkernel::MicrokernelDialect"];
}

#endif // GC_DIALECT_MICROKERNELPASSES
1 change: 1 addition & 0 deletions lib/gc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_subdirectory(Utils)
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS
MLIRIR
MLIRSupport
MLIRMicrokernelTransforms
MLIRBufferizationToMemRef
MLIRBufferizationPipelines)

Expand Down
5 changes: 4 additions & 1 deletion lib/gc/Transforms/Microkernel/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR)
gc_set_mlir_link_components(MLIR_LINK_COMPONENTS MLIRIR MLIRMicrokernel)

include(onednn)

gc_add_mlir_dialect_library(MLIRMicrokernelTransforms
ConvertLinalgToMicrokernel.cpp
ExpandMicrokernel.cpp
ConvertMicrokernelToDnnlFunc.cpp
EarlyDispatchMicrokernel.cpp
MicrokernelInvariantCodeMotion.cpp
MergeBranchMicrokernelContext.cpp

ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/
Expand Down
201 changes: 201 additions & 0 deletions lib/gc/Transforms/Microkernel/EarlyDispatchMicrokernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
//===-- EarlyDispatchMicrokernel.cpp - Dispatch before runtime --*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <sstream>

#include "gc/Transforms/Microkernel/BrgemmRuntimeUtils.h"
#include "gc/Transforms/Microkernel/MicrokernelPasses.h"
#include "gc/Transforms/Utils/ValueUtils.h"
#include "oneapi/dnnl/dnnl_types.h"

namespace mlir::microkernel {
#define GEN_PASS_DEF_EARLYDISPATCHMICROKERNEL
#include "gc/Transforms/Microkernel/MicrokernelPasses.h.inc"

#define DEBUG_TYPE "early-dispatch-microkernel"

static FailureOr<std::string>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using string as the key for global kernel cache might not be the good option when considering the post-op fusion or m_mask stuff in the future.

Copy link
Contributor Author

@huanghaixin008 huanghaixin008 Sep 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The string name could be lengthy but I think it's pretty self-explanatory and makes pass independent of compiler's internal state. Consider such a scenario with IR going through following pipeline:
EarlyDispatchMicrokernel -> SomeLoweringPassProducingNewBrgemm -> ConvertLinalgToBrgemm -> EarlyDispatchMicrokernel
If we use global var name as cache key, we can easily dedup between first and second EarlyDispatchMicrokernel. I think it's hard to implement this if we keep global kernel cache as some compiler's internal state, especially under test/debug scenarios using mlir-opt where we might run the passes one by one in different spawns of process.

For stuff like post-op fusion and mask, we can add the attr into the name as well, with predefined format, e.g.:
llvm.mlir.global internal @g_mask_1 = xxxx
llvm.mlir.global internal @g_dispatched_microkernel_brgemm_..._mask{g_mask_1}_fusing_relu() ...

createGlobalKernelHandleName(RewriterBase &rewriter,
microkernel::BrgemmDispatchOp op) {
// TODO(haixin): Add runtime backend type to global name
std::stringstream ss;
ss << "g_dispatched_microkernel_brgemm";

auto flags = op.getFlagsAttr();
for (auto flag : flags) {
auto brgemmFlag = dyn_cast_or_null<microkernel::BrgemmFlagsAttr>(flag);
if (!brgemmFlag)
return failure();
if (brgemmFlag.getValue() == BrgemmFlags::LIST)
return failure();
if (brgemmFlag.getValue() == BrgemmFlags::BETA_0)
ss << "_init";
}

// M, N, K, LDA, LDB, LDC, stride_a, stride_b
// they are in the same order with BrgemmDispatchOp inputs
ArrayRef<int64_t> inputs = op.getInputsAttr().asArrayRef();
for (auto input : inputs) {
ss << "_" << input;
}

// dtypeA, dtypeB
auto dtypes = op.getDataType();
if (dtypes.size() != 2)
return failure();
ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[0]);
ss << "_" << getDnnlDataTypeVal(rewriter, dtypes[1]);

return ss.str();
}

// get or create global kernel handle with initializer, identified by
// `kernelName`
static FailureOr<LLVM::GlobalOp>
getOrCreateGlobalKernelHandle(RewriterBase &rewriter, ModuleOp module,
const std::string &kernelName,
microkernel::BrgemmDispatchOp op) {
// Create the global at the entry of the module
LLVM::GlobalOp global = module.lookupSymbol<LLVM::GlobalOp>(kernelName);
if (global)
return global;

auto global_type = op.getResults().getType();
FlatSymbolRefAttr ctorName =
SymbolRefAttr::get(module->getContext(), kernelName + "_ctor");
if (module.lookupSymbol<LLVM::LLVMFuncOp>(ctorName.getAttr()))
return failure();

OpBuilder::InsertionGuard insertGuard(rewriter);
rewriter.setInsertionPointToStart(module.getBody());
global = rewriter.create<LLVM::GlobalOp>(
module.getLoc(), global_type, /*isConstant=*/false,
LLVM::Linkage::Internal, kernelName, Attribute(),
/*alignment=*/0);

// create ctor for this global, which needs to be LLVMFuncOp
LLVM::LLVMFuncOp ctorFunc = rewriter.create<LLVM::LLVMFuncOp>(
module.getLoc(), ctorName.getValue(),
LLVM::LLVMFunctionType::get(global_type, {}, false));

Location loc = ctorFunc.getLoc();
Block *entryBlock = ctorFunc.addEntryBlock(rewriter);
rewriter.setInsertionPointToEnd(entryBlock);

auto dispatch = op.clone();
rewriter.insert(dispatch);
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
rewriter.create<LLVM::StoreOp>(loc, dispatch.getResults(), globalPtr);
rewriter.create<LLVM::ReturnOp>(loc, dispatch.getResults());

// initialize the gloabl with global_ctors, as the initializer of global
// does not allow side effect
rewriter.setInsertionPointToStart(module.getBody());
LLVM::GlobalCtorsOp global_ctors = nullptr;
for (auto &op : module->getRegion(0).front()) {
auto ctorOp = dyn_cast<LLVM::GlobalCtorsOp>(op);
if (ctorOp) {
global_ctors = ctorOp;
break;
}
}

SmallVector<Attribute> ctorRefs;
SmallVector<Attribute> priorities;
if (global_ctors) {
auto ctorRefsAttr = global_ctors.getCtors();
auto prioritiesAttr = global_ctors.getPriorities();
for (auto &&[ctor, prior] : llvm::zip(ctorRefsAttr, prioritiesAttr)) {
ctorRefs.push_back(ctor);
priorities.push_back(prior);
}
LLVM_DEBUG(llvm::dbgs()
<< "After append ctors: " << ctorRefs.size() << "\n");
}
ctorRefs.push_back(ctorName);
// Set new ctor's priority to lowest
priorities.push_back(IntegerAttr::get(rewriter.getI32Type(), INT_MAX));
if (global_ctors) {
LLVM_DEBUG(llvm::dbgs() << "Replace existing ctors\n");
// If there's existing ctors
rewriter.replaceOpWithNewOp<LLVM::GlobalCtorsOp>(
global_ctors, rewriter.getArrayAttr(ctorRefs),
rewriter.getArrayAttr(priorities));
} else {
LLVM_DEBUG(llvm::dbgs() << "Create new ctor\n");
rewriter.create<LLVM::GlobalCtorsOp>(module.getLoc(),
rewriter.getArrayAttr(ctorRefs),
rewriter.getArrayAttr(priorities));
}
return global;
}

class EarlyDispatchBrgemmRewriter
: public OpRewritePattern<microkernel::BrgemmDispatchOp> {
public:
using OpRewritePattern<microkernel::BrgemmDispatchOp>::OpRewritePattern;

LogicalResult matchAndRewrite(microkernel::BrgemmDispatchOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
ModuleOp module = op->template getParentOfType<ModuleOp>();
func::FuncOp func = op->template getParentOfType<func::FuncOp>();

auto globalKernelName = createGlobalKernelHandleName(rewriter, op);
if (failed(globalKernelName)) {
return rewriter.notifyMatchFailure(
op, "Failed to create global kernel handle name");
}

// Generate kernel handle global name
auto globalKernel =
getOrCreateGlobalKernelHandle(rewriter, module, *globalKernelName, op);
if (failed(globalKernel)) {
return rewriter.notifyMatchFailure(
op, "Failed to create global kernel handle");
}

// Inject global val loading into start of function
auto funcBlock = &func.getBody().front();
rewriter.setInsertionPointToStart(funcBlock);
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, *globalKernel);
Value globalVal = rewriter.create<LLVM::LoadOp>(
loc, op.getResults().getType(), globalPtr);
rewriter.moveOpAfter(op, funcBlock, funcBlock->begin());
rewriter.replaceOp(op, globalVal);
return success();
}
};

class EarlyDispatchMicrokernel
: public impl::EarlyDispatchMicrokernelBase<EarlyDispatchMicrokernel> {
public:
using impl::EarlyDispatchMicrokernelBase<
EarlyDispatchMicrokernel>::EarlyDispatchMicrokernelBase;
void runOnOperation() final {
RewritePatternSet patterns(&getContext());
patterns.add<EarlyDispatchBrgemmRewriter>(&getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));

// Ignore newly created Ops
GreedyRewriteConfig config;
config.strictMode = GreedyRewriteStrictness::ExistingOps;
if (failed(
applyPatternsAndFoldGreedily(getOperation(), patternSet, config)))
signalPassFailure();
}
};

} // namespace mlir::microkernel
Loading
Loading