Skip to content

Commit d794dc7

Browse files
author
Ivy Zhang
authored
optimize thread local cache for brgemm (#353)
* add thread local cache for brgemm * encapsulate thread local cache * best perf * use static func * fix comment * use vector * update * use static vector with size = 1024 * fix comment * fix comment
1 parent 993e095 commit d794dc7

File tree

1 file changed

+49
-59
lines changed

1 file changed

+49
-59
lines changed

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 49 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ __attribute__((weak)) void print_verbose_header() {}
4848
} // namespace dnnl
4949

5050
static constexpr int PALETTE_SIZE = 64;
51+
static constexpr int DEFAULT_KERNEL_SIZE = 1024;
52+
static constexpr int MAX_KERNEL_SIZE = 2048;
5153

5254
using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5355
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
@@ -56,81 +58,78 @@ static std::shared_mutex g_brgemm_lock;
5658
struct brgemm_cache_info_t {
5759
brgemm_desc_t desc;
5860
brgemm_kernel_t *kernel;
59-
std::shared_ptr<char[]> palette;
61+
std::unique_ptr<char[]> palette;
6062
};
6163

62-
static std::vector<brgemm_cache_info_t> g_cache;
64+
static std::vector<brgemm_cache_info_t> g_cache(DEFAULT_KERNEL_SIZE);
65+
static int64_t g_kernel_id = -1;
6366

6467
// TODO(haixin): use syscall to determine page size?
6568
static constexpr size_t SCRATCH_SIZE = 2 * 4096;
6669
// TODO(haixin): need to use custom thread management for scratch in the future?
6770
static thread_local char scratch[SCRATCH_SIZE] = {0};
6871

69-
static std::unordered_map<int64_t, brgemm_cache_info_t> &get_tl_cache() {
70-
thread_local std::unordered_map<int64_t, brgemm_cache_info_t> tl_cache;
71-
return tl_cache;
72-
}
73-
7472
extern "C" {
7573

7674
int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
7775
int64_t LDB, int64_t LDC, int64_t stride_a,
7876
int64_t stride_b, float beta, int64_t dtypeA,
7977
int64_t dtypeB) {
80-
brgemm_desc_t desc;
81-
brgemm_kernel_t *kernel;
82-
8378
auto dnnl_dtypeA = static_cast<dnnl_data_type_t>(dtypeA);
8479
auto dnnl_dtypeB = static_cast<dnnl_data_type_t>(dtypeB);
8580
int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA);
8681
int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB);
8782
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8883

84+
write_lock_guard_t g(g_brgemm_lock);
85+
g_kernel_id++;
86+
assert(g_kernel_id < MAX_KERNEL_SIZE &&
87+
"Too many brgemm kernels are created");
88+
if (g_kernel_id >= DEFAULT_KERNEL_SIZE) {
89+
if (g_kernel_id >= (int64_t)g_cache.size()) {
90+
g_cache.resize(g_kernel_id + 1);
91+
}
92+
}
93+
8994
dnnl::impl::status_t status = brgemm_desc_init(
90-
&desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd,
91-
dnnl_dtypeA, dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
92-
brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K,
93-
&stride_info);
95+
&g_cache[g_kernel_id].desc, cpu_isa_t::isa_undef,
96+
brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB,
97+
/*transA=*/false, /*transB=*/false, brgemm_layout_t::brgemm_row_major,
98+
1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info);
9499
assert(status == dnnl::impl::status::success &&
95100
"Failed to initialize BRGEMM descriptor");
96101

97-
status = brgemm_kernel_create(&kernel, desc);
102+
status = brgemm_kernel_create(&g_cache[g_kernel_id].kernel,
103+
g_cache[g_kernel_id].desc);
98104
assert(status == dnnl::impl::status::success &&
99105
"Failed to JIT BRGEMM kernel");
100106

101107
brgemm_attr_t dnnl_attrs;
102-
brgemm_desc_set_attr(&desc, dnnl_attrs);
108+
brgemm_desc_set_attr(&g_cache[g_kernel_id].desc, dnnl_attrs);
103109

104-
// TODO(haixin): Reuse identical palettes across kernels
105-
std::shared_ptr<char[]> palette_buffer;
106-
if (desc.is_tmm) {
107-
palette_buffer.reset(new char[PALETTE_SIZE]);
108-
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
110+
if (g_cache[g_kernel_id].desc.is_tmm) {
111+
g_cache[g_kernel_id].palette.reset(new char[PALETTE_SIZE]);
112+
status = brgemm_init_tiles(g_cache[g_kernel_id].desc,
113+
g_cache[g_kernel_id].palette.get());
109114
assert(status == dnnl::impl::status::success &&
110115
"Failed to initialize palette for BRGEMM");
111116
}
112117

113-
write_lock_guard_t g(g_brgemm_lock);
114-
g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
115-
return g_cache.size() - 1;
118+
return g_kernel_id;
116119
}
117120

118121
void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
119-
assert(kernel_idx >= 0 && "Invalid kernel handler");
120-
auto &tl_cache = get_tl_cache();
121-
auto it = tl_cache.find(kernel_idx);
122-
if (it == tl_cache.end()) {
123-
read_lock_guard_t g(g_brgemm_lock);
124-
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
125-
it = tl_cache.insert({kernel_idx, g_cache[kernel_idx]}).first;
122+
std::unique_ptr<read_lock_guard_t> lock_guard;
123+
if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
124+
lock_guard = std::make_unique<read_lock_guard_t>(g_brgemm_lock);
126125
}
127-
brgemm_desc_t &desc = it->second.desc;
128-
char *palette_buffer = it->second.palette.get();
129-
126+
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
127+
"Invalid kernel handler");
128+
brgemm_desc_t &desc = g_cache[kernel_idx].desc;
130129
if (!desc.is_tmm) {
131130
return;
132131
}
133-
132+
char *palette_buffer = g_cache[kernel_idx].palette.get();
134133
assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel");
135134
amx_tile_configure(palette_buffer);
136135
}
@@ -146,35 +145,26 @@ void dnnl_brgemm_tilerelease() {
146145
void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
147146
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
148147
int num) {
149-
auto &tl_cache = get_tl_cache();
150-
if (tl_cache.find(kernel_idx) == tl_cache.end()) {
151-
read_lock_guard_t g(g_brgemm_lock);
152-
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
153-
"Invalid kernel handler");
154-
auto updated_cache =
155-
tl_cache.insert(std::make_pair(kernel_idx, g_cache[kernel_idx]));
156-
assert(updated_cache.second && "insert into thread local cache");
148+
std::unique_ptr<read_lock_guard_t> lock_guard;
149+
if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
150+
lock_guard = std::make_unique<read_lock_guard_t>(g_brgemm_lock);
157151
}
158-
auto it = tl_cache.find(kernel_idx);
159-
brgemm_kernel_t *kernel = it->second.kernel;
160-
brgemm_desc_t *desc_ptr = &it->second.desc;
161-
152+
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
153+
"Invalid kernel handler");
154+
brgemm_desc_t &desc = g_cache[kernel_idx].desc;
155+
brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel;
162156
assert(kernel && "Invalid brgemm kernel pointer");
163-
assert(desc_ptr && "Invalid brgemm descriptor pointer");
164-
165157
size_t A_offset_in_bytes =
166-
dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset;
158+
dnnl::impl::types::data_type_size(desc.dt_a) * A_offset;
167159
size_t B_offset_in_bytes =
168-
dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset;
160+
dnnl::impl::types::data_type_size(desc.dt_b) * B_offset;
169161
size_t C_offset_in_bytes =
170-
dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset;
171-
172-
char *A_arith = (char *)A;
173-
char *B_arith = (char *)B;
174-
char *C_arith = (char *)C;
175-
brgemm_kernel_execute(kernel, num, (void *)(A_arith + A_offset_in_bytes),
176-
(void *)(B_arith + B_offset_in_bytes), nullptr,
177-
(void *)(C_arith + C_offset_in_bytes), (void *)scratch);
162+
dnnl::impl::types::data_type_size(desc.dt_c) * C_offset;
163+
char *A_arith = static_cast<char *>(A) + A_offset_in_bytes;
164+
char *B_arith = static_cast<char *>(B) + B_offset_in_bytes;
165+
char *C_arith = static_cast<char *>(C) + C_offset_in_bytes;
166+
brgemm_kernel_execute(kernel, num, A_arith, B_arith, nullptr, C_arith,
167+
scratch);
178168
}
179169
}
180170

0 commit comments

Comments
 (0)