@@ -54,7 +54,7 @@ using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
54
54
static std::shared_mutex g_brgemm_lock;
55
55
56
56
struct brgemm_cache_info_t {
57
- brgemm_desc_t * desc;
57
+ std::shared_ptr< brgemm_desc_t > desc;
58
58
brgemm_kernel_t *kernel;
59
59
std::shared_ptr<char []> palette;
60
60
};
@@ -70,14 +70,15 @@ static std::vector<brgemm_cache_info_t> &get_tl_cache() {
70
70
thread_local std::vector<brgemm_cache_info_t > tl_cache;
71
71
return tl_cache;
72
72
}
73
- brgemm_desc_t desc;
74
73
75
74
extern " C" {
76
75
77
76
int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
78
77
int64_t LDB, int64_t LDC, int64_t stride_a,
79
78
int64_t stride_b, float beta, int64_t dtypeA,
80
79
int64_t dtypeB) {
80
+ std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81
+ brgemm_desc_t *desc = desc_ptr.get ();
81
82
brgemm_kernel_t *kernel;
82
83
auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
83
84
auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
@@ -86,31 +87,32 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
86
87
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
87
88
88
89
dnnl::impl::status_t status = brgemm_desc_init (
89
- & desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd,
90
- dnnl_dtypeA, dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
90
+ desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA ,
91
+ dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
91
92
brgemm_layout_t ::brgemm_row_major, 1 .0f , beta, LDA, LDB, LDC, M, N, K,
92
93
&stride_info);
93
94
assert (status == dnnl::impl::status::success &&
94
95
" Failed to initialize BRGEMM descriptor" );
95
96
96
- status = brgemm_kernel_create (&kernel, desc);
97
+ status = brgemm_kernel_create (&kernel, * desc);
97
98
assert (status == dnnl::impl::status::success &&
98
99
" Failed to JIT BRGEMM kernel" );
99
100
100
101
brgemm_attr_t dnnl_attrs;
101
- brgemm_desc_set_attr (& desc, dnnl_attrs);
102
+ brgemm_desc_set_attr (desc, dnnl_attrs);
102
103
103
104
// TODO(haixin): Reuse identical palettes across kernels
104
105
std::shared_ptr<char []> palette_buffer;
105
- if (desc. is_tmm ) {
106
+ if (desc-> is_tmm ) {
106
107
palette_buffer.reset (new char [PALETTE_SIZE]);
107
- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer.get ());
108
+ dnnl::impl::status_t status =
109
+ brgemm_init_tiles (*desc, palette_buffer.get ());
108
110
assert (status == dnnl::impl::status::success &&
109
111
" Failed to initialize palette for BRGEMM" );
110
112
}
111
113
112
114
write_lock_guard_t g (g_brgemm_lock);
113
- g_cache.push_back (brgemm_cache_info_t {&desc , kernel, palette_buffer});
115
+ g_cache.push_back (brgemm_cache_info_t {desc_ptr , kernel, palette_buffer});
114
116
return g_cache.size () - 1 ;
115
117
}
116
118
@@ -126,10 +128,10 @@ void dnnl_brgemm_tileconfig(int64_t kernel_idx) {
126
128
}
127
129
tl_cache[kernel_idx] = g_cache[kernel_idx];
128
130
}
129
- brgemm_desc_t & desc = * tl_cache[kernel_idx].desc ;
131
+ brgemm_desc_t * desc = tl_cache[kernel_idx].desc . get () ;
130
132
char *palette_buffer = tl_cache[kernel_idx].palette .get ();
131
133
132
- if (!desc. is_tmm ) {
134
+ if (!desc-> is_tmm ) {
133
135
return ;
134
136
}
135
137
@@ -159,7 +161,7 @@ void dnnl_brgemm_execute(int64_t kernel_idx, void *A, uint64_t A_offset,
159
161
tl_cache[kernel_idx] = g_cache[kernel_idx];
160
162
}
161
163
brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel ;
162
- brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc ;
164
+ brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc . get () ;
163
165
164
166
assert (kernel && " Invalid brgemm kernel pointer" );
165
167
assert (desc_ptr && " Invalid brgemm descriptor pointer" );
0 commit comments