-
Notifications
You must be signed in to change notification settings - Fork 17
add thread local cache for brgemm #350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,9 +53,14 @@ using read_lock_guard_t = std::shared_lock<std::shared_mutex>; | |
using write_lock_guard_t = std::unique_lock<std::shared_mutex>; | ||
static std::shared_mutex g_brgemm_lock; | ||
|
||
static std::vector<brgemm_desc_t> g_brgemm_desc_list; | ||
static std::vector<brgemm_kernel_t *> g_brgemm_kernel_list; | ||
static std::vector<std::unique_ptr<char[]>> g_brgemm_palette; | ||
struct brgemm_cache_info_t { | ||
brgemm_desc_t desc; | ||
brgemm_kernel_t *kernel; | ||
std::shared_ptr<char[]> palette; | ||
}; | ||
|
||
static std::vector<brgemm_cache_info_t> g_cache; | ||
static thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache; | ||
|
||
// TODO(haixin): use syscall to determine page size? | ||
static constexpr size_t SCRATCH_SIZE = 2 * 4096; | ||
|
@@ -93,33 +98,36 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA, | |
brgemm_desc_set_attr(&desc, dnnl_attrs); | ||
|
||
// TODO(haixin): Reuse identical palettes across kernels | ||
char *palette_buffer = nullptr; | ||
std::shared_ptr<char[]> palette_buffer(new char[PALETTE_SIZE], | ||
std::default_delete<char[]>()); | ||
if (desc.is_tmm) { | ||
palette_buffer = new char[PALETTE_SIZE]; | ||
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer); | ||
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get()); | ||
assert(status == dnnl::impl::status::success && | ||
"Failed to initialize palette for BRGEMM"); | ||
} | ||
|
||
write_lock_guard_t g(g_brgemm_lock); | ||
g_brgemm_desc_list.push_back(desc); | ||
g_brgemm_kernel_list.push_back(kernel); | ||
g_brgemm_palette.emplace_back(palette_buffer); | ||
|
||
return g_brgemm_desc_list.size() - 1; | ||
g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer}); | ||
return g_cache.size() - 1; | ||
} | ||
|
||
void dnnl_brgemm_tileconfig(int64_t kernel_idx) { | ||
char *palette_buffer = nullptr; | ||
{ | ||
assert(kernel_idx >= 0 && "Invalid kernel handler"); | ||
auto it = tl_cache.find(kernel_idx); | ||
if (it == tl_cache.end()) { | ||
read_lock_guard_t g(g_brgemm_lock); | ||
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && | ||
"Invalid kernel handler"); | ||
brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx]; | ||
if (!desc.is_tmm) { | ||
return; | ||
} | ||
palette_buffer = g_brgemm_palette[kernel_idx].get(); | ||
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler"); | ||
|
||
brgemm_cache_info_t tl_content = {g_cache[kernel_idx].desc, | ||
g_cache[kernel_idx].kernel, | ||
g_cache[kernel_idx].palette}; | ||
it = tl_cache.insert({kernel_idx, tl_content}).first; | ||
} | ||
brgemm_desc_t &desc = it->second.desc; | ||
char *palette_buffer = it->second.palette.get(); | ||
|
||
if (!desc.is_tmm) { | ||
return; | ||
} | ||
|
||
assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel"); | ||
|
@@ -137,24 +145,32 @@ void dnnl_brgemm_tilerelease() { | |
void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset, | ||
void *B, uint64_t B_offset, void *C, uint64_t C_offset, | ||
int num) { | ||
brgemm_kernel_t *kernel = nullptr; | ||
size_t A_offset_in_bytes; | ||
size_t B_offset_in_bytes; | ||
size_t C_offset_in_bytes; | ||
{ | ||
if (tl_cache.find(kernel_idx) == tl_cache.end()) { | ||
read_lock_guard_t g(g_brgemm_lock); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since it's thread local, do we still need this lock? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when the target brgemm kernel is not found in thread_local cache, we still need to lock the global cache to get the target brgemm. |
||
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() && | ||
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() && | ||
"Invalid kernel handler"); | ||
|
||
brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx]; | ||
kernel = g_brgemm_kernel_list[kernel_idx]; | ||
|
||
A_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_a) * A_offset; | ||
B_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_b) * B_offset; | ||
C_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_c) * C_offset; | ||
brgemm_cache_info_t tl_content = {g_cache[kernel_idx].desc, | ||
g_cache[kernel_idx].kernel, | ||
g_cache[kernel_idx].palette}; | ||
auto updated_cache = | ||
tl_cache.insert(std::make_pair(kernel_idx, tl_content)); | ||
assert(updated_cache.second && "insert into thread local cache"); | ||
} | ||
auto it = tl_cache.find(kernel_idx); | ||
brgemm_kernel_t *kernel = it->second.kernel; | ||
brgemm_desc_t *desc_ptr = &it->second.desc; | ||
|
||
assert(kernel && "Invalid brgemm kernel pointer"); | ||
assert(desc_ptr && "Invalid brgemm descriptor pointer"); | ||
|
||
size_t A_offset_in_bytes = | ||
dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset; | ||
size_t B_offset_in_bytes = | ||
dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset; | ||
size_t C_offset_in_bytes = | ||
dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset; | ||
|
||
char *A_arith = (char *)A; | ||
char *B_arith = (char *)B; | ||
char *C_arith = (char *)C; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We only need to
new
palette buffer whendesc.is_tmm
istrue
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.