[Transform] Introduce microkernel dialect optimization passes #296

Merged
merged 103 commits into from
Sep 5, 2024

Contributor

@huanghaixin008 huanghaixin008 commented Aug 28, 2024

Tracking #297

This PR introduces the following passes to improve the runtime efficiency of the microkernel dialect:

  • EarlyDispatchMicrokernel: Dispatches microkernels at initialization time to reduce runtime cost, and merges identical microkernel dispatches along the way.
  • InvariantMicrokernelCodeMotion: Hoists invariant microkernel-related code to improve performance.
  • MergeBranchMicrokernelContext: Merges identical microkernel context code across branches and hoists it out of the branch where possible, enabling further hoisting.
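The idea behind EarlyDispatchMicrokernel (dispatch once during initialization, reuse the result at runtime, and merge identical dispatches) can be sketched outside of MLIR in plain C++. This is only an illustration under stated assumptions: `BrgemmConfig`, `makeKey`, `KernelRegistry`, and `dispatchAtInit` are hypothetical names that do not appear in the PR, and the integer handle stands in for whatever the real dispatch call returns.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for a microkernel configuration. The fields are
// illustrative and not taken from the PR.
struct BrgemmConfig {
  std::int64_t m;
  std::int64_t n;
  std::int64_t k;
  std::string dtype;
};

// Hypothetical cache key: a readable string derived from the configuration.
std::string makeKey(const BrgemmConfig &cfg) {
  return "brgemm_" + std::to_string(cfg.m) + "x" + std::to_string(cfg.n) +
         "x" + std::to_string(cfg.k) + "_" + cfg.dtype;
}

// Toy registry: each distinct key is dispatched once, later requests reuse it.
class KernelRegistry {
public:
  int getOrDispatch(const BrgemmConfig &cfg) {
    const std::string key = makeKey(cfg);
    auto it = handles.find(key);
    if (it != handles.end())
      return it->second; // Identical dispatch: reuse the existing handle.
    ++dispatchCount;     // Stands in for the expensive dispatch call.
    const int handle = nextHandle++;
    handles.emplace(key, handle);
    return handle;
  }

  int numDispatches() const { return dispatchCount; }

private:
  std::unordered_map<std::string, int> handles;
  int nextHandle = 0;
  int dispatchCount = 0;
};

// "Initialization time": resolve every call site up front so that the hot
// path only reads a handle instead of dispatching again.
std::vector<int> dispatchAtInit(KernelRegistry &registry,
                                const std::vector<BrgemmConfig> &callSites) {
  std::vector<int> result;
  result.reserve(callSites.size());
  for (const BrgemmConfig &cfg : callSites)
    result.push_back(registry.getOrDispatch(cfg));
  return result;
}

// Usage: three call sites, two of them identical, so only two dispatches.
void exampleUsage() {
  KernelRegistry registry;
  std::vector<int> handles = dispatchAtInit(
      registry, {{32, 32, 64, "f32"}, {32, 32, 64, "f32"}, {16, 32, 64, "f32"}});
  assert(handles[0] == handles[1]);
  assert(registry.numDispatches() == 2);
}
```

In the actual pass the same effect is obtained by rewriting IR rather than by a runtime registry; the sketch only shows why merging identical dispatches and moving them out of the hot path reduces runtime cost.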

return
}

// CHECK: BRGEMM DONE
Contributor

Could you check for the result instead?

Contributor Author

Result correctness check added.

Comment on lines 91 to 95
auto tryAddrOfOp = dyn_cast_or_null<LLVM::AddressOfOp>(
tryLoadOp.getOperand().getDefiningOp());
if (!tryAddrOfOp)
return nullptr;
return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName());
Contributor

Suggested change
- auto tryAddrOfOp = dyn_cast_or_null<LLVM::AddressOfOp>(
-     tryLoadOp.getOperand().getDefiningOp());
- if (!tryAddrOfOp)
-   return nullptr;
- return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName());
+ if (auto tryAddrOfOp = dyn_cast_or_null<LLVM::AddressOfOp>(
+         tryLoadOp.getOperand().getDefiningOp()))
+   return traceDispatchInGlobalCtor(module, tryAddrOfOp.getGlobalName());

Contributor

etc.

Contributor Author

fixed.

if (callee != StringAttr::get(op->getContext(), DNNL_BRGEMM_DISPATCH_NAME))
return nullptr;
return tryCallOp;
} else if (auto tryLoadOp = dyn_cast_or_null<LLVM::LoadOp>(kernelProducer)) {
Contributor

Suggested change
- } else if (auto tryLoadOp = dyn_cast_or_null<LLVM::LoadOp>(kernelProducer)) {
+ }
+ if (auto tryLoadOp = dyn_cast_or_null<LLVM::LoadOp>(kernelProducer)) {

Contributor Author

fixed.


#define DEBUG_TYPE "early-dispatch-microkernel"

static FailureOr<std::string>
Contributor

Using a string as the key for the global kernel cache might not be a good option once post-op fusion or m_mask support is considered in the future.

Contributor Author

@huanghaixin008 huanghaixin008 Sep 3, 2024

The string name can be lengthy, but I think it is self-explanatory and keeps the pass independent of the compiler's internal state. Consider IR going through the following pipeline:
EarlyDispatchMicrokernel -> SomeLoweringPassProducingNewBrgemm -> ConvertLinalgToBrgemm -> EarlyDispatchMicrokernel
If we use the global variable name as the cache key, we can easily deduplicate between the first and second EarlyDispatchMicrokernel runs. I think this would be hard to implement if the global kernel cache were kept as compiler-internal state, especially in test/debug scenarios using mlir-opt, where the passes may be run one by one in separate processes.

For things like post-op fusion and masks, we can encode the attributes into the name as well, using a predefined format, e.g.:
llvm.mlir.global internal @g_mask_1 = xxxx
llvm.mlir.global internal @g_dispatched_microkernel_brgemm_..._mask{g_mask_1}_fusing_relu() ...
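A minimal sketch of the naming idea described above, in plain C++ without MLIR. Everything here is an assumption made for illustration: `KernelExtras`, `makeGlobalName`, and `needsNewGlobal` are hypothetical, the base name is passed in as an opaque string because the example above elides it, and a `std::unordered_set` stands in for whatever lookup of existing globals the pass performs on the module.

```cpp
#include <cassert>
#include <string>
#include <unordered_set>

// Hypothetical optional attributes of a dispatched kernel. An empty string
// means the attribute is absent.
struct KernelExtras {
  std::string maskGlobal;  // e.g. "g_mask_1"
  std::string fusedPostOp; // e.g. "relu"
};

// Builds a global name by appending the optional attributes to a base name,
// following the "_mask{...}_fusing_..." format suggested in the comment.
// The base name is treated as opaque and supplied by the caller.
std::string makeGlobalName(const std::string &baseName,
                           const KernelExtras &extras) {
  std::string name = baseName;
  if (!extras.maskGlobal.empty())
    name += "_mask{" + extras.maskGlobal + "}";
  if (!extras.fusedPostOp.empty())
    name += "_fusing_" + extras.fusedPostOp;
  return name;
}

// Returns true if `name` has not been seen yet, meaning a new global would
// have to be created; returns false if an earlier run already created it.
// The set is an assumed stand-in for the globals already present in the IR.
bool needsNewGlobal(std::unordered_set<std::string> &existingGlobals,
                    const std::string &name) {
  return existingGlobals.insert(name).second;
}

// Usage: a second run that derives the same name finds the existing entry.
void exampleNameDedup() {
  std::unordered_set<std::string> globals;
  const std::string name =
      makeGlobalName("g_kernel_base", {"g_mask_1", "relu"});
  assert(needsNewGlobal(globals, name));  // First run creates the global.
  assert(!needsNewGlobal(globals, name)); // Second run reuses it.
}
```

Because the key is derived entirely from information that is visible in the IR, two separate invocations of the pass (for example, two mlir-opt processes) would arrive at the same name without sharing any in-memory state, which is the property the comment argues for.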

Contributor

ciyongch commented Sep 4, 2024

Please check the failed case:
scripts/correctness.sh: line 24: 5132 Segmentation fault (core dumped) python3 -m benchgc --verbose 0 --driver linalg --case batch_reduce_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:512x32xf32

@huanghaixin008
Contributor Author

> Please check the failed case: scripts/correctness.sh: line 24: 5132 Segmentation fault (core dumped) python3 -m benchgc --verbose 0 --driver linalg --case batch_reduce_matmul --md 0:16x512x64xf32 --md 1:16x64x32xf32 --md 2:512x32xf32

The correctness check failure has been fixed.

@ciyongch ciyongch merged commit bc0014b into main Sep 5, 2024
6 checks passed
@lmontigny lmontigny added this to the 0.1 CPU - General milestone Sep 5, 2024
Development

Successfully merging this pull request may close these issues.

Introduce microkernel dialect optimization passes
5 participants