Skip to content

Commit 0fcf3e9

Browse files
author
ZhangYan
committed
update
1 parent cb77f94 commit 0fcf3e9

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
5454
static std::shared_mutex g_brgemm_lock;
5555

5656
struct brgemm_cache_info_t {
57-
brgemm_desc_t *desc;
57+
std::shared_ptr<brgemm_desc_t> desc;
5858
brgemm_kernel_t *kernel;
5959
std::shared_ptr<char[]> palette;
6060
};
@@ -70,14 +70,15 @@ static std::vector<brgemm_cache_info_t> &get_tl_cache() {
7070
thread_local std::vector<brgemm_cache_info_t> tl_cache;
7171
return tl_cache;
7272
}
73-
brgemm_desc_t desc;
7473

7574
extern "C" {
7675

7776
int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
7877
int64_t LDB, int64_t LDC, int64_t stride_a,
7978
int64_t stride_b, float beta, int64_t dtypeA,
8079
int64_t dtypeB) {
80+
std::shared_ptr<brgemm_desc_t> desc_ptr = std::make_shared<brgemm_desc_t>();
81+
brgemm_desc_t *desc = desc_ptr.get();
8182
brgemm_kernel_t *kernel;
8283
auto dnnl_dtypeA = static_cast<dnnl_data_type_t>(dtypeA);
8384
auto dnnl_dtypeB = static_cast<dnnl_data_type_t>(dtypeB);
@@ -86,31 +87,32 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
8687
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8788

8889
dnnl::impl::status_t status = brgemm_desc_init(
89-
&desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd,
90-
dnnl_dtypeA, dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
90+
desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA,
91+
dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
9192
brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K,
9293
&stride_info);
9394
assert(status == dnnl::impl::status::success &&
9495
"Failed to initialize BRGEMM descriptor");
9596

96-
status = brgemm_kernel_create(&kernel, desc);
97+
status = brgemm_kernel_create(&kernel, *desc);
9798
assert(status == dnnl::impl::status::success &&
9899
"Failed to JIT BRGEMM kernel");
99100

100101
brgemm_attr_t dnnl_attrs;
101-
brgemm_desc_set_attr(&desc, dnnl_attrs);
102+
brgemm_desc_set_attr(desc, dnnl_attrs);
102103

103104
// TODO(haixin): Reuse identical palettes across kernels
104105
std::shared_ptr<char[]> palette_buffer;
105-
if (desc.is_tmm) {
106+
if (desc->is_tmm) {
106107
palette_buffer.reset(new char[PALETTE_SIZE]);
107-
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
108+
dnnl::impl::status_t status =
109+
brgemm_init_tiles(*desc, palette_buffer.get());
108110
assert(status == dnnl::impl::status::success &&
109111
"Failed to initialize palette for BRGEMM");
110112
}
111113

112114
write_lock_guard_t g(g_brgemm_lock);
113-
g_cache.push_back(brgemm_cache_info_t{&desc, kernel, palette_buffer});
115+
g_cache.push_back(brgemm_cache_info_t{desc_ptr, kernel, palette_buffer});
114116
return g_cache.size() - 1;
115117
}
116118

@@ -126,10 +128,10 @@ void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
126128
}
127129
tl_cache[kernel_idx] = g_cache[kernel_idx];
128130
}
129-
brgemm_desc_t &desc = *tl_cache[kernel_idx].desc;
131+
brgemm_desc_t *desc = tl_cache[kernel_idx].desc.get();
130132
char *palette_buffer = tl_cache[kernel_idx].palette.get();
131133

132-
if (!desc.is_tmm) {
134+
if (!desc->is_tmm) {
133135
return;
134136
}
135137

@@ -159,7 +161,7 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
159161
tl_cache[kernel_idx] = g_cache[kernel_idx];
160162
}
161163
brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel;
162-
brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc;
164+
brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc.get();
163165

164166
assert(kernel && "Invalid brgemm kernel pointer");
165167
assert(desc_ptr && "Invalid brgemm descriptor pointer");

0 commit comments

Comments
 (0)