Skip to content

Commit 5c04e70

Browse files
author
ZhangYan
committed
use static vector with size = 1024
1 parent 93f4031 commit 5c04e70

File tree

1 file changed

+50
-66
lines changed

1 file changed

+50
-66
lines changed

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 50 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -48,93 +48,86 @@ __attribute__((weak)) void print_verbose_header() {}
4848
} // namespace dnnl
4949

5050
static constexpr int PALETTE_SIZE = 64;
51+
static constexpr int DEFAULT_KERNEL_SIZE = 1024;
5152

5253
using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
5354
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5455
static std::shared_mutex g_brgemm_lock;
5556

5657
struct brgemm_cache_info_t {
57-
std::shared_ptr<brgemm_desc_t> desc;
58+
brgemm_desc_t desc;
5859
brgemm_kernel_t *kernel;
59-
std::shared_ptr<char[]> palette;
60+
std::unique_ptr<char[]> palette;
6061
};
6162

62-
static std::vector<brgemm_cache_info_t> g_cache;
63+
static std::vector<brgemm_cache_info_t> g_cache(DEFAULT_KERNEL_SIZE);
64+
static int64_t kernel_id = -1;
6365

6466
// TODO(haixin): use syscall to determine page size?
6567
static constexpr size_t SCRATCH_SIZE = 2 * 4096;
6668
// TODO(haixin): need to use custom thread management for scratch in the future?
6769
static thread_local char scratch[SCRATCH_SIZE] = {0};
6870

69-
static std::vector<brgemm_cache_info_t> &get_tl_cache() {
70-
thread_local std::vector<brgemm_cache_info_t> tl_cache;
71-
return tl_cache;
72-
}
73-
7471
extern "C" {
7572

7673
int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
7774
int64_t LDB, int64_t LDC, int64_t stride_a,
7875
int64_t stride_b, float beta, int64_t dtypeA,
7976
int64_t dtypeB) {
80-
std::shared_ptr<brgemm_desc_t> desc_ptr = std::make_shared<brgemm_desc_t>();
81-
brgemm_desc_t *desc = desc_ptr.get();
82-
brgemm_kernel_t *kernel;
8377
auto dnnl_dtypeA = static_cast<dnnl_data_type_t>(dtypeA);
8478
auto dnnl_dtypeB = static_cast<dnnl_data_type_t>(dtypeB);
8579
int64_t dtypeA_size = dnnl::impl::types::data_type_size(dnnl_dtypeA);
8680
int64_t dtypeB_size = dnnl::impl::types::data_type_size(dnnl_dtypeB);
8781
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8882

83+
write_lock_guard_t g(g_brgemm_lock);
84+
kernel_id++;
85+
86+
if (kernel_id >= DEFAULT_KERNEL_SIZE) {
87+
if (kernel_id >= (int64_t)g_cache.size()) {
88+
g_cache.resize(kernel_id + 1);
89+
}
90+
}
91+
8992
dnnl::impl::status_t status = brgemm_desc_init(
90-
desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA,
91-
dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
92-
brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K,
93-
&stride_info);
93+
&g_cache[kernel_id].desc, cpu_isa_t::isa_undef,
94+
brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB,
95+
/*transA=*/false, /*transB=*/false, brgemm_layout_t::brgemm_row_major,
96+
1.0f, beta, LDA, LDB, LDC, M, N, K, &stride_info);
9497
assert(status == dnnl::impl::status::success &&
9598
"Failed to initialize BRGEMM descriptor");
9699

97-
status = brgemm_kernel_create(&kernel, *desc);
100+
status =
101+
brgemm_kernel_create(&g_cache[kernel_id].kernel, g_cache[kernel_id].desc);
98102
assert(status == dnnl::impl::status::success &&
99103
"Failed to JIT BRGEMM kernel");
100104

101105
brgemm_attr_t dnnl_attrs;
102-
brgemm_desc_set_attr(desc, dnnl_attrs);
103-
104-
// TODO(haixin): Reuse identical palettes across kernels
105-
std::shared_ptr<char[]> palette_buffer;
106-
if (desc->is_tmm) {
107-
palette_buffer.reset(new char[PALETTE_SIZE]);
108-
dnnl::impl::status_t status =
109-
brgemm_init_tiles(*desc, palette_buffer.get());
106+
brgemm_desc_set_attr(&g_cache[kernel_id].desc, dnnl_attrs);
107+
108+
if (g_cache[kernel_id].desc.is_tmm) {
109+
g_cache[kernel_id].palette.reset(new char[PALETTE_SIZE]);
110+
status = brgemm_init_tiles(g_cache[kernel_id].desc,
111+
g_cache[kernel_id].palette.get());
110112
assert(status == dnnl::impl::status::success &&
111113
"Failed to initialize palette for BRGEMM");
112114
}
113115

114-
write_lock_guard_t g(g_brgemm_lock);
115-
g_cache.push_back(brgemm_cache_info_t{desc_ptr, kernel, palette_buffer});
116-
return g_cache.size() - 1;
116+
return kernel_id;
117117
}
118118

119119
void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
120-
assert(kernel_idx >= 0 && "Invalid kernel handler");
121-
auto &tl_cache = get_tl_cache();
122-
if (kernel_idx >= (int64_t)tl_cache.size() ||
123-
tl_cache[kernel_idx].kernel == nullptr) {
124-
read_lock_guard_t g(g_brgemm_lock);
125-
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
126-
if (kernel_idx >= (int64_t)tl_cache.size()) {
127-
tl_cache.resize(kernel_idx + 1);
128-
}
129-
tl_cache[kernel_idx] = g_cache[kernel_idx];
120+
std::unique_ptr<read_lock_guard_t> lock_guard;
121+
if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
122+
lock_guard = std::make_unique<read_lock_guard_t>(g_brgemm_lock);
130123
}
131-
brgemm_desc_t *desc = tl_cache[kernel_idx].desc.get();
132-
char *palette_buffer = tl_cache[kernel_idx].palette.get();
133-
134-
if (!desc->is_tmm) {
124+
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
125+
"Invalid kernel handler");
126+
brgemm_desc_t &desc = g_cache[kernel_idx].desc;
127+
if (!desc.is_tmm) {
135128
return;
136129
}
137-
130+
char *palette_buffer = g_cache[kernel_idx].palette.get();
138131
assert(palette_buffer != nullptr && "Invalid palette for BRGEMM kernel");
139132
amx_tile_configure(palette_buffer);
140133
}
@@ -150,35 +143,26 @@ void dnnl_brgemm_tilerelease() {
150143
void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
151144
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
152145
int num) {
153-
auto &tl_cache = get_tl_cache();
154-
if (kernel_idx >= (int64_t)tl_cache.size() ||
155-
tl_cache[kernel_idx].kernel == nullptr) {
156-
read_lock_guard_t g(g_brgemm_lock);
157-
assert(kernel_idx < (int64_t)g_cache.size() && "Invalid kernel handler");
158-
if (kernel_idx >= (int64_t)tl_cache.size()) {
159-
tl_cache.resize(kernel_idx + 1);
160-
}
161-
tl_cache[kernel_idx] = g_cache[kernel_idx];
146+
std::unique_ptr<read_lock_guard_t> lock_guard;
147+
if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
148+
lock_guard = std::make_unique<read_lock_guard_t>(g_brgemm_lock);
162149
}
163-
brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel;
164-
brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc.get();
165-
150+
assert(kernel_idx >= 0 && kernel_idx < (int64_t)g_cache.size() &&
151+
"Invalid kernel handler");
152+
brgemm_desc_t &desc = g_cache[kernel_idx].desc;
153+
brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel;
166154
assert(kernel && "Invalid brgemm kernel pointer");
167-
assert(desc_ptr && "Invalid brgemm descriptor pointer");
168-
169155
size_t A_offset_in_bytes =
170-
dnnl::impl::types::data_type_size(desc_ptr->dt_a) * A_offset;
156+
dnnl::impl::types::data_type_size(desc.dt_a) * A_offset;
171157
size_t B_offset_in_bytes =
172-
dnnl::impl::types::data_type_size(desc_ptr->dt_b) * B_offset;
158+
dnnl::impl::types::data_type_size(desc.dt_b) * B_offset;
173159
size_t C_offset_in_bytes =
174-
dnnl::impl::types::data_type_size(desc_ptr->dt_c) * C_offset;
175-
176-
char *A_arith = (char *)A;
177-
char *B_arith = (char *)B;
178-
char *C_arith = (char *)C;
179-
brgemm_kernel_execute(kernel, num, (void *)(A_arith + A_offset_in_bytes),
180-
(void *)(B_arith + B_offset_in_bytes), nullptr,
181-
(void *)(C_arith + C_offset_in_bytes), (void *)scratch);
160+
dnnl::impl::types::data_type_size(desc.dt_c) * C_offset;
161+
char *A_arith = static_cast<char *>(A) + A_offset_in_bytes;
162+
char *B_arith = static_cast<char *>(B) + B_offset_in_bytes;
163+
char *C_arith = static_cast<char *>(C) + C_offset_in_bytes;
164+
brgemm_kernel_execute(kernel, num, A_arith, B_arith, nullptr, C_arith,
165+
scratch);
182166
}
183167
}
184168

0 commit comments

Comments
 (0)