optimize thread local cache for brgemm #353


Merged · 14 commits · Oct 9, 2024
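For context: the file below implements the C runtime entry points that compiled code calls to JIT and run BRGEMM microkernels. A minimal sketch of the calling sequence, assuming f32 inputs (`dnnl_f32 == 3` in oneDNN's `dnnl_data_type_t`); the declarations match this file, but the argument values are illustrative and real call sites are emitted by the compiler:

```cpp
#include <cstdint>

// Declarations as they appear in BrgemmOnednn.cpp.
extern "C" {
int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
                             int64_t LDB, int64_t LDC, int64_t stride_a,
                             int64_t stride_b, float beta, int64_t dtypeA,
                             int64_t dtypeB);
void dnnl_brgemm_tileconfig(int64_t kernel_idx);
void dnnl_brgemm_tilerelease();
void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
                         void *B, uint64_t B_offset, void *C, uint64_t C_offset,
                         int num);
}

void run_f32_brgemm(float *A, float *B, float *C) {
  // JIT the kernel once; the returned id indexes the global cache (g_cache).
  // dtype 3 is dnnl_f32; all shape/stride values here are illustrative.
  int64_t id = dnnl_brgemm_dispatch(/*M=*/32, /*N=*/32, /*K=*/32,
                                    /*LDA=*/32, /*LDB=*/32, /*LDC=*/32,
                                    /*stride_a=*/32 * 32, /*stride_b=*/32 * 32,
                                    /*beta=*/0.0f, /*dtypeA=*/3, /*dtypeB=*/3);
  dnnl_brgemm_tileconfig(id); // configures AMX tiles; early-outs for non-tmm kernels
  dnnl_brgemm_execute(id, A, /*A_offset=*/0, B, /*B_offset=*/0, C,
                      /*C_offset=*/0, /*num=*/1);
  dnnl_brgemm_tilerelease();
}
```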
108 changes: 49 additions & 59 deletions lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp
@@ -48,6 +48,8 @@ __attribute__((weak)) void print_verbose_header() {}
} // namespace dnnl

static constexpr int PALETTE_SIZE = 64;
+static constexpr int DEFAULT_KERNEL_SIZE = 1024;
+static constexpr int MAX_KERNEL_SIZE = 2048;

using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
@@ -56,81 +58,78 @@ static std::shared_mutex g_brgemm_lock;
struct brgemm_cache_info_t {
  brgemm_desc_t desc;
  brgemm_kernel_t *kernel;
-  std::shared_ptr<char[]> palette;
+  std::unique_ptr<char[]> palette;
};

-static std::vector<brgemm_cache_info_t> g_cache;
+static std::vector<brgemm_cache_info_t> g_cache(DEFAULT_KERNEL_SIZE);
static int64_t g_kernel_id = -1;

// TODO(haixin): use syscall to determine page size?
static constexpr size_t SCRATCH_SIZE = 2 * 4096;
// TODO(haixin): need to use custom thread management for scratch in the future?
static thread_local char scratch[SCRATCH_SIZE] = {0};

-static std::unordered_map<int64_t, brgemm_cache_info_t> &get_tl_cache() {
-  thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache;
-  return tl_cache;
-}

extern "C" {

int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
                             int64_t LDB, int64_t LDC, int64_t stride_a,
                             int64_t stride_b, float beta, int64_t dtypeA,
                             int64_t dtypeB) {
-  brgemm_desc_t desc;
-  brgemm_kernel_t *kernel;
-
  auto dnnl_dtypeA = static_cast<dnnl_data_type_t>(dtypeA);
  auto dnnl_dtypeB = static_cast<dnnl_data_type_t>(dtypeB);
  int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA);
  int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB);
  brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};

+  write_lock_guard_t g(g_brgemm_lock);
+  g_kernel_id++;
+  assert(g_kernel_id < MAX_KERNEL_SIZE &&
+         "Too many brgemm kernels are created");
+  if (g_kernel_id >= DEFAULT_KERNEL_SIZE) {
+    if (g_kernel_id >= (int64_t)g_cache.size()) {
+      g_cache.resize(g_kernel_id + 1);
Contributor:
This should probably have some constraints and an eviction policy.

Author (reply):
Thanks for the advice. Added a constraint. I believe an eviction policy is not necessary in this case: in the vast majority of scenarios, the number of kernels does not exceed 1024.

+    }
+  }
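For reference, the kind of bounded cache with an eviction policy the reviewer suggests could look roughly like the sketch below. It is illustrative only and not part of this PR; the `LruCache` class and its interface are invented here. Evicting entries would also invalidate kernel ids already baked into generated code, which supports the author's choice of a hard `MAX_KERNEL_SIZE` cap instead:

```cpp
// Illustrative sketch only (not part of this PR): an LRU-bounded cache that
// evicts the least-recently-used entry once `capacity` entries are reached.
#include <cstdint>
#include <list>
#include <unordered_map>
#include <utility>

template <typename V> class LruCache {
  using Entry = std::pair<int64_t, V>;
  size_t capacity_;
  std::list<Entry> order_; // front = most recently used
  std::unordered_map<int64_t, typename std::list<Entry>::iterator> index_;

public:
  explicit LruCache(size_t capacity) : capacity_(capacity) {}

  V *find(int64_t key) {
    auto it = index_.find(key);
    if (it == index_.end())
      return nullptr;
    // Move the hit to the front so it counts as recently used.
    order_.splice(order_.begin(), order_, it->second);
    return &it->second->second;
  }

  void insert(int64_t key, V value) {
    if (order_.size() >= capacity_) {
      // Evict the least-recently-used entry; a real implementation would
      // also have to destroy the evicted JIT kernel and its palette.
      index_.erase(order_.back().first);
      order_.pop_back();
    }
    order_.emplace_front(key, std::move(value));
    index_[key] = order_.begin();
  }
};
```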

  dnnl::impl::status_t status = brgemm_desc_init(
-      &desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd,
-      dnnl_dtypeA, dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
-      brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K,
-      &stride_info);
+      &g_cache[g_kernel_id].desc, cpu_isa_t::isa_undef,
+      brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB,
+      /*transA=*/false, /*transB=*/false, brgemm_layout_t::brgemm_row_major,
+      1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info);
  assert(status == dnnl::impl::status::success &&
         "Failed to initialize BRGEMM descriptor");

-  status = brgemm_kernel_create(&kernel, desc);
+  status = brgemm_kernel_create(&g_cache[g_kernel_id].kernel,
+                                g_cache[g_kernel_id].desc);
  assert(status == dnnl::impl::status::success &&
         "Failed to JIT BRGEMM kernel");

  brgemm_attr_t dnnl_attrs;
-  brgemm_desc_set_attr(&desc, dnnl_attrs);
+  brgemm_desc_set_attr(&g_cache[g_kernel_id].desc, dnnl_attrs);

  // TODO(haixin): Reuse identical palettes across kernels
-  std::shared_ptr<char[]> palette_buffer;
-  if (desc.is_tmm) {
-    palette_buffer.reset(new char[PALETTE_SIZE]);
-    dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
+  if (g_cache[g_kernel_id].desc.is_tmm) {
+    g_cache[g_kernel_id].palette.reset(new char[PALETTE_SIZE]);
+    status = brgemm_init_tiles(g_cache[g_kernel_id].desc,
+                               g_cache[g_kernel_id].palette.get());
    assert(status == dnnl::impl::status::success &&
           "Failed to initialize palette for BRGEMM");
  }

-  write_lock_guard_t g(g_brgemm_lock);
-  g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
-  return g_cache.size() - 1;
+  return g_kernel_id;
}

void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
-  assert(kernel_idx >= 0 && "Invalid kernel handler");
-  auto &tl_cache = get_tl_cache();
-  auto it = tl_cache.find(kernel_idx);
-  if (it == tl_cache.end()) {
-    read_lock_guard_t g(g_brgemm_lock);
-    assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
-    it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first;
+  std::unique_ptr<read_lock_guard_t> lock_guard;
+  if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
+    lock_guard = std::make_unique<read_lock_guard_t>(g_brgemm_lock);
  }
-  brgemm_desc_t &desc = it->second.desc;
-  char *palette_buffer = it->second.palette.get();
-
+  assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
+         "Invalid kernel handler");
+  brgemm_desc_t &desc = g_cache[kernel_idx].desc;
  if (!desc.is_tmm) {
    return;
  }
-
+  char *palette_buffer = g_cache[kernel_idx].palette.get();
  assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel");
  amx_tile_configure(palette_buffer);
}
@@ -146,35 +145,26 @@ void dnnl_brgemm_tilerelease() {
void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
                         void *B, uint64_t B_offset, void *C, uint64_t C_offset,
                         int num) {
-  auto &tl_cache = get_tl_cache();
-  if (tl_cache.find(kernel_idx) == tl_cache.end()) {
-    read_lock_guard_t g(g_brgemm_lock);
-    assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
-           "Invalid kernel handler");
-    auto updated_cache =
-        tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx]));
-    assert(updated_cache.second && "insert into thread local cache");
+  std::unique_ptr<read_lock_guard_t> lock_guard;
+  if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
+    lock_guard = std::make_unique<read_lock_guard_t>(g_brgemm_lock);
  }
-  auto it = tl_cache.find(kernel_idx);
-  brgemm_kernel_t *kernel = it->second.kernel;
-  brgemm_desc_t *desc_ptr = &it->second.desc;

+  assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
+         "Invalid kernel handler");
+  brgemm_desc_t &desc = g_cache[kernel_idx].desc;
+  brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel;
  assert(kernel && "Invalid brgemm kernel pointer");
-  assert(desc_ptr && "Invalid brgemm descriptor pointer");

  size_t A_offset_in_bytes =
-      dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset;
+      dnnl::impl::types::data_type_size(desc.dt_a) * A_offset;
  size_t B_offset_in_bytes =
-      dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset;
+      dnnl::impl::types::data_type_size(desc.dt_b) * B_offset;
  size_t C_offset_in_bytes =
-      dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset;
-
-  char *A_arith = (char *)A;
-  char *B_arith = (char *)B;
-  char *C_arith = (char *)C;
-  brgemm_kernel_execute(kernel, num, (void *)(A_arith + A_offset_in_bytes),
-                        (void *)(B_arith + B_offset_in_bytes), nullptr,
-                        (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
+      dnnl::impl::types::data_type_size(desc.dt_c) * C_offset;
+  char *A_arith = static_cast<char *>(A) + A_offset_in_bytes;
+  char *B_arith = static_cast<char *>(B) + B_offset_in_bytes;
+  char *C_arith = static_cast<char *>(C) + C_offset_in_bytes;
+  brgemm_kernel_execute(kernel, num, A_arith, B_arith, nullptr, C_arith,
+                        scratch);
}
}
