Skip to content

optimize thread local cache for brgemm #353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 9, 2024
42 changes: 22 additions & 20 deletions lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
static std::shared_mutex g_brgemm_lock;

struct brgemm_cache_info_t {
brgemm_desc_t desc;
brgemm_desc_t *desc;
brgemm_kernel_t *kernel;
std::shared_ptr<char[]> palette;
};
Expand All @@ -66,20 +66,19 @@ static constexpr size_t SCRATCH_SIZE = 2 * 4096;
// TODO(haixin): need to use custom thread management for scratch in the future?
static thread_local char scratch[SCRATCH_SIZE] = {0};

static std::unordered_map<int64_t, brgemm_cache_info_t> &get_tl_cache() {
thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache;
static std::vector<brgemm_cache_info_t> &get_tl_cache() {
thread_local std::vector<brgemm_cache_info_t> tl_cache;
return tl_cache;
}
brgemm_desc_t desc;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we using a global temporary desc without synchronization?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to preserve the desc object's address, which is used in g_cache.push_back(brgemm_cache_info_t{&desc, kernel, palette_buffer}). Otherwise, the address would become invalid once the dispatch function returns.

Copy link
Contributor

@huanghaixin008 huanghaixin008 Sep 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But now every &desc points to the same object, and brgemm_desc_init (which modifies the global desc) has no synchronization at all?
If we really need to use a pointer to desc, I think we can set a maximum number of dispatched kernels and reserve a desc vector pool of that size to avoid vector resizing; then we can use pointers to the vector elements.


extern "C" {

int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
int64_t LDB, int64_t LDC, int64_t stride_a,
int64_t stride_b, float beta, int64_t dtypeA,
int64_t dtypeB) {
brgemm_desc_t desc;
brgemm_kernel_t *kernel;

auto dnnl_dtypeA = static_cast<dnnl_data_type_t>(dtypeA);
auto dnnl_dtypeB = static_cast<dnnl_data_type_t>(dtypeB);
int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA);
Expand Down Expand Up @@ -111,21 +110,24 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
}

write_lock_guard_t g(g_brgemm_lock);
g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
g_cache.push_back(brgemm_cache_info_t{&desc, kernel, palette_buffer});
return g_cache.size() - 1;
}

void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
assert(kernel_idx >= 0 && "Invalid kernel handler");
auto &tl_cache = get_tl_cache();
auto it = tl_cache.find(kernel_idx);
if (it == tl_cache.end()) {
if (kernel_idx >= (int64_t)tl_cache.size() ||
tl_cache[kernel_idx].kernel == nullptr) {
read_lock_guard_t g(g_brgemm_lock);
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first;
if (kernel_idx >= (int64_t)tl_cache.size()) {
tl_cache.resize(kernel_idx + 1);
}
tl_cache[kernel_idx] = g_cache[kernel_idx];
}
brgemm_desc_t &desc = it->second.desc;
char *palette_buffer = it->second.palette.get();
brgemm_desc_t &desc = *tl_cache[kernel_idx].desc;
char *palette_buffer = tl_cache[kernel_idx].palette.get();

if (!desc.is_tmm) {
return;
Expand All @@ -147,17 +149,17 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
int num) {
auto &tl_cache = get_tl_cache();
if (tl_cache.find(kernel_idx) == tl_cache.end()) {
if (kernel_idx >= (int64_t)tl_cache.size() ||
tl_cache[kernel_idx].kernel == nullptr) {
read_lock_guard_t g(g_brgemm_lock);
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
"Invalid kernel handler");
auto updated_cache =
tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx]));
assert(updated_cache.second && "insert into thread local cache");
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
if (kernel_idx >= (int64_t)tl_cache.size()) {
tl_cache.resize(kernel_idx + 1);
}
tl_cache[kernel_idx] = g_cache[kernel_idx];
}
auto it = tl_cache.find(kernel_idx);
brgemm_kernel_t *kernel = it->second.kernel;
brgemm_desc_t *desc_ptr = &it->second.desc;
brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel;
brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc;

assert(kernel && "Invalid brgemm kernel pointer");
assert(desc_ptr && "Invalid brgemm descriptor pointer");
Expand Down