Skip to content

Commit 76346ac

Browse files
author
ZhangYan
committed
update
1 parent cb77f94 commit 76346ac

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

lib/gc/ExecutionEngine/CPURuntime/Microkernel/BrgemmOnednn.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,15 @@ static std::vector<brgemm_cache_info_t> &get_tl_cache() {
7070
thread_local std::vector<brgemm_cache_info_t> tl_cache;
7171
return tl_cache;
7272
}
73-
brgemm_desc_t desc;
7473

7574
extern "C" {
7675

7776
int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
7877
int64_t LDB, int64_t LDC, int64_t stride_a,
7978
int64_t stride_b, float beta, int64_t dtypeA,
8079
int64_t dtypeB) {
80+
std::shared_ptr<brgemm_desc_t> desc_ptr = std::make_shared<brgemm_desc_t>();
81+
brgemm_desc_t *desc = desc_ptr.get();
8182
brgemm_kernel_t *kernel;
8283
auto dnnl_dtypeA = static_cast<dnnl_data_type_t>(dtypeA);
8384
auto dnnl_dtypeB = static_cast<dnnl_data_type_t>(dtypeB);
@@ -86,31 +87,32 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
8687
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
8788

8889
dnnl::impl::status_t status = brgemm_desc_init(
89-
&desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd,
90-
dnnl_dtypeA, dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
90+
desc, cpu_isa_t::isa_undef, brgemm_batch_kind_t::brgemm_strd, dnnl_dtypeA,
91+
dnnl_dtypeB, /*transA=*/false, /*transB=*/false,
9192
brgemm_layout_t::brgemm_row_major, 1.0f, beta, LDA, LDB, LDC, M, N, K,
9293
&stride_info);
9394
assert(status == dnnl::impl::status::success &&
9495
"Failed to initialize BRGEMM descriptor");
9596

96-
status = brgemm_kernel_create(&kernel, desc);
97+
status = brgemm_kernel_create(&kernel, *desc);
9798
assert(status == dnnl::impl::status::success &&
9899
"Failed to JIT BRGEMM kernel");
99100

100101
brgemm_attr_t dnnl_attrs;
101-
brgemm_desc_set_attr(&desc, dnnl_attrs);
102+
brgemm_desc_set_attr(desc, dnnl_attrs);
102103

103104
// TODO(haixin): Reuse identical palettes across kernels
104105
std::shared_ptr<char[]> palette_buffer;
105-
if (desc.is_tmm) {
106+
if (desc->is_tmm) {
106107
palette_buffer.reset(new char[PALETTE_SIZE]);
107-
dnnl::impl::status_t status = brgemm_init_tiles(desc, palette_buffer.get());
108+
dnnl::impl::status_t status =
109+
brgemm_init_tiles(*desc, palette_buffer.get());
108110
assert(status == dnnl::impl::status::success &&
109111
"Failed to initialize palette for BRGEMM");
110112
}
111113

112114
write_lock_guard_t g(g_brgemm_lock);
113-
g_cache.push_back(brgemm_cache_info_t{&desc, kernel, palette_buffer});
115+
g_cache.push_back(brgemm_cache_info_t{desc, kernel, palette_buffer});
114116
return g_cache.size() - 1;
115117
}
116118

0 commit comments

Comments
 (0)