Skip to content

Commit a62e88e

Browse files
author
Ivy Zhang
authored
add thread local cache for brgemm (#350)
* add thread local cache for brgemm
* use static func
* fix comment
1 parent ce6d1d3 commit a62e88e

File tree

1 file changed

+47
-33
lines changed

1 file changed

+47
-33
lines changed

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 47 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,24 @@ using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5353
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5454
static std::shared_mutex g_brgemm_lock;
5555

56-
static std::vector<brgemm_desc_t> g_brgemm_desc_list;
57-
static std::vector<brgemm_kernel_t *> g_brgemm_kernel_list;
58-
static std::vector<std::unique_ptr<char[]>> g_brgemm_palette;
56+
struct brgemm_cache_info_t {
57+
brgemm_desc_t desc;
58+
brgemm_kernel_t *kernel;
59+
std::shared_ptr<char[]> palette;
60+
};
61+
62+
static std::vector<brgemm_cache_info_t> g_cache;
5963

6064
// TODO(haixin): use syscall to determine page size?
6165
static constexpr size_t SCRATCH_SIZE = 2 * 4096;
6266
// TODO(haixin): need to use custom thread management for scratch in the future?
6367
static thread_local char scratch[SCRATCH_SIZE] = {0};
6468

69+
static std::unordered_map<int64_t, brgemm_cache_info_t> &get_tl_cache() {
70+
thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache;
71+
return tl_cache;
72+
}
73+
6574
extern "C" {
6675

6776
int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
@@ -93,33 +102,33 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
93102
brgemm_desc_set_attr(&desc, dnnl_attrs);
94103

95104
// TODO(haixin): Reuse identical palettes across kernels
96-
char *palette_buffer = nullptr;
105+
std::shared_ptr<char[]> palette_buffer;
97106
if (desc.is_tmm) {
98-
palette_buffer = new char[PALETTE_SIZE];
99-
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer);
107+
palette_buffer.reset(new char[PALETTE_SIZE]);
108+
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
100109
assert(status == dnnl::impl::status::success &&
101110
"Failed to initialize palette for BRGEMM");
102111
}
103112

104113
write_lock_guard_t g(g_brgemm_lock);
105-
g_brgemm_desc_list.push_back(desc);
106-
g_brgemm_kernel_list.push_back(kernel);
107-
g_brgemm_palette.emplace_back(palette_buffer);
108-
109-
return g_brgemm_desc_list.size() - 1;
114+
g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
115+
return g_cache.size() - 1;
110116
}
111117

112118
void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
113-
char *palette_buffer = nullptr;
114-
{
119+
assert(kernel_idx >= 0 && "Invalid kernel handler");
120+
auto &tl_cache = get_tl_cache();
121+
auto it = tl_cache.find(kernel_idx);
122+
if (it == tl_cache.end()) {
115123
read_lock_guard_t g(g_brgemm_lock);
116-
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() &&
117-
"Invalid kernel handler");
118-
brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
119-
if (!desc.is_tmm) {
120-
return;
121-
}
122-
palette_buffer = g_brgemm_palette[kernel_idx].get();
124+
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
125+
it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first;
126+
}
127+
brgemm_desc_t &desc = it->second.desc;
128+
char *palette_buffer = it->second.palette.get();
129+
130+
if (!desc.is_tmm) {
131+
return;
123132
}
124133

125134
assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel");
@@ -137,24 +146,29 @@ void dnnl_brgemm_tilerelease() {
137146
void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
138147
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
139148
int num) {
140-
brgemm_kernel_t *kernel = nullptr;
141-
size_t A_offset_in_bytes;
142-
size_t B_offset_in_bytes;
143-
size_t C_offset_in_bytes;
144-
{
149+
auto &tl_cache = get_tl_cache();
150+
if (tl_cache.find(kernel_idx) == tl_cache.end()) {
145151
read_lock_guard_t g(g_brgemm_lock);
146-
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_brgemm_desc_list.size() &&
152+
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
147153
"Invalid kernel handler");
148-
149-
brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
150-
kernel = g_brgemm_kernel_list[kernel_idx];
151-
152-
A_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_a) * A_offset;
153-
B_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_b) * B_offset;
154-
C_offset_in_bytes = dnnl::impl::types::data_type_size(desc.dt_c) * C_offset;
154+
auto updated_cache =
155+
tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx]));
156+
assert(updated_cache.second && "insert into thread local cache");
155157
}
158+
auto it = tl_cache.find(kernel_idx);
159+
brgemm_kernel_t *kernel = it->second.kernel;
160+
brgemm_desc_t *desc_ptr = &it->second.desc;
156161

157162
assert(kernel && "Invalid brgemm kernel pointer");
163+
assert(desc_ptr && "Invalid brgemm descriptor pointer");
164+
165+
size_t A_offset_in_bytes =
166+
dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset;
167+
size_t B_offset_in_bytes =
168+
dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset;
169+
size_t C_offset_in_bytes =
170+
dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset;
171+
158172
char *A_arith = (char *)A;
159173
char *B_arith = (char *)B;
160174
char *C_arith = (char *)C;

0 commit comments

Comments
 (0)