@@ -70,14 +70,15 @@ static std::vector<brgemm_cache_info_t> &get_tl_cache() {
70
70
thread_local std::vector<brgemm_cache_info_t > tl_cache;
71
71
return tl_cache;
72
72
}
73
- brgemm_desc_t desc;
74
73
75
74
extern " C" {
76
75
77
76
int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
78
77
int64_t LDB, int64_t LDC, int64_t stride_a,
79
78
int64_t stride_b, float beta, int64_t dtypeA,
80
79
int64_t dtypeB) {
80
+ std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81
+ brgemm_desc_t *desc = desc_ptr.get ();
81
82
brgemm_kernel_t *kernel;
82
83
auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
83
84
auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
@@ -86,31 +87,32 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
86
87
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
87
88
88
89
dnnl::impl::status_t status = brgemm_desc_init (
89
- & desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd,
90
- dnnl_dtypeA, dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
90
+ desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA ,
91
+ dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
91
92
brgemm_layout_t ::brgemm_row_major, 1 .0f , beta, LDA, LDB, LDC, M, N, K,
92
93
&stride_info);
93
94
assert (status == dnnl::impl::status::success &&
94
95
" Failed to initialize BRGEMM descriptor" );
95
96
96
- status = brgemm_kernel_create (&kernel, desc);
97
+ status = brgemm_kernel_create (&kernel, * desc);
97
98
assert (status == dnnl::impl::status::success &&
98
99
" Failed to JIT BRGEMM kernel" );
99
100
100
101
brgemm_attr_t dnnl_attrs;
101
- brgemm_desc_set_attr (& desc, dnnl_attrs);
102
+ brgemm_desc_set_attr (desc, dnnl_attrs);
102
103
103
104
// TODO(haixin): Reuse identical palettes across kernels
104
105
std::shared_ptr<char []> palette_buffer;
105
- if (desc. is_tmm ) {
106
+ if (desc-> is_tmm ) {
106
107
palette_buffer.reset (new char [PALETTE_SIZE]);
107
- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer.get ());
108
+ dnnl::impl::status_t status =
109
+ brgemm_init_tiles (*desc, palette_buffer.get ());
108
110
assert (status == dnnl::impl::status::success &&
109
111
" Failed to initialize palette for BRGEMM" );
110
112
}
111
113
112
114
write_lock_guard_t g (g_brgemm_lock);
113
- g_cache.push_back (brgemm_cache_info_t {& desc, kernel, palette_buffer});
115
+ g_cache.push_back (brgemm_cache_info_t {desc, kernel, palette_buffer});
114
116
return g_cache.size () - 1 ;
115
117
}
116
118
0 commit comments