@@ -48,93 +48,90 @@ __attribute__((weak)) void print_verbose_header() {}
48
48
} // namespace dnnl
49
49
50
50
static constexpr int PALETTE_SIZE = 64 ;
51
+ static constexpr int DEFAULT_KERNEL_SIZE = 1024 ;
51
52
52
53
using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
53
54
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
54
55
static std::shared_mutex g_brgemm_lock;
55
56
56
57
struct brgemm_cache_info_t {
57
- std::shared_ptr< brgemm_desc_t > desc;
58
+ brgemm_desc_t desc;
58
59
brgemm_kernel_t *kernel;
59
- std::shared_ptr <char []> palette;
60
+ std::unique_ptr <char []> palette;
60
61
};
61
62
62
- static std::vector<brgemm_cache_info_t > g_cache;
63
+ static std::vector<brgemm_cache_info_t > g_cache (DEFAULT_KERNEL_SIZE);
64
+ static int64_t kernel_id = -1 ;
63
65
64
66
// TODO(haixin): use syscall to determine page size?
65
67
static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
66
68
// TODO(haixin): need to use custom thread management for scratch in the future?
67
69
static thread_local char scratch[SCRATCH_SIZE] = {0 };
68
70
69
- static std::vector<brgemm_cache_info_t > &get_tl_cache () {
70
- thread_local std::vector<brgemm_cache_info_t > tl_cache;
71
- return tl_cache;
72
- }
73
-
74
71
extern " C" {
75
72
76
73
int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
77
74
int64_t LDB, int64_t LDC, int64_t stride_a,
78
75
int64_t stride_b, float beta, int64_t dtypeA,
79
76
int64_t dtypeB) {
80
- std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81
- brgemm_desc_t *desc = desc_ptr.get ();
82
- brgemm_kernel_t *kernel;
83
77
auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
84
78
auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
85
79
int64_t dtypeA_size = dnnl::impl::types::data_type_size (dnnl_dtypeA);
86
80
int64_t dtypeB_size = dnnl::impl::types::data_type_size (dnnl_dtypeB);
87
81
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
88
82
83
+ write_lock_guard_t g (g_brgemm_lock);
84
+ kernel_id++;
85
+
86
+ if (kernel_id >= DEFAULT_KERNEL_SIZE) {
87
+ if (kernel_id >= (int64_t )g_cache.size ()) {
88
+ g_cache.resize (kernel_id + 1 );
89
+ }
90
+ }
91
+
89
92
dnnl::impl::status_t status = brgemm_desc_init (
90
- desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA,
91
- dnnl_dtypeB, /* transA=*/ false , /* transB=*/ false ,
92
- brgemm_layout_t ::brgemm_row_major, 1 .0f , beta, LDA, LDB, LDC, M, N, K,
93
- &stride_info);
94
- assert (status == dnnl::impl::status::success &&
95
- " Failed to initialize BRGEMM descriptor" );
93
+ &g_cache[kernel_id].desc , cpu_isa_t ::isa_undef,
94
+ brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB,
95
+ /* transA=*/ false , /* transB=*/ false , brgemm_layout_t ::brgemm_row_major,
96
+ 1 .0f , beta, LDA, LDB, LDC, M, N, K, &stride_info);
97
+ if (status != dnnl::impl::status::success) {
98
+ return -1 ;
99
+ }
96
100
97
- status = brgemm_kernel_create (&kernel, *desc);
98
- assert (status == dnnl::impl::status::success &&
99
- " Failed to JIT BRGEMM kernel" );
101
+ status =
102
+ brgemm_kernel_create (&g_cache[kernel_id].kernel , g_cache[kernel_id].desc );
103
+ if (status != dnnl::impl::status::success) {
104
+ return -1 ;
105
+ }
100
106
101
107
brgemm_attr_t dnnl_attrs;
102
- brgemm_desc_set_attr (desc, dnnl_attrs);
103
-
104
- // TODO(haixin): Reuse identical palettes across kernels
105
- std::shared_ptr<char []> palette_buffer;
106
- if (desc->is_tmm ) {
107
- palette_buffer.reset (new char [PALETTE_SIZE]);
108
- dnnl::impl::status_t status =
109
- brgemm_init_tiles (*desc, palette_buffer.get ());
110
- assert (status == dnnl::impl::status::success &&
111
- " Failed to initialize palette for BRGEMM" );
108
+ brgemm_desc_set_attr (&g_cache[kernel_id].desc , dnnl_attrs);
109
+
110
+ if (g_cache[kernel_id].desc .is_tmm ) {
111
+ g_cache[kernel_id].palette .reset (new char [PALETTE_SIZE]);
112
+ status = brgemm_init_tiles (g_cache[kernel_id].desc ,
113
+ g_cache[kernel_id].palette .get ());
114
+ if (status != dnnl::impl::status::success) {
115
+ return -1 ;
116
+ }
112
117
}
113
118
114
- write_lock_guard_t g (g_brgemm_lock);
115
- g_cache.push_back (brgemm_cache_info_t {desc_ptr, kernel, palette_buffer});
116
- return g_cache.size () - 1 ;
119
+ return kernel_id;
117
120
}
118
121
119
122
void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
120
- assert (kernel_idx >= 0 && " Invalid kernel handler" );
121
- auto &tl_cache = get_tl_cache ();
122
- if (kernel_idx >= (int64_t )tl_cache.size () ||
123
- tl_cache[kernel_idx].kernel == nullptr ) {
124
- read_lock_guard_t g (g_brgemm_lock);
125
- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
126
- if (kernel_idx >= (int64_t )tl_cache.size ()) {
127
- tl_cache.resize (kernel_idx + 1 );
128
- }
129
- tl_cache[kernel_idx] = g_cache[kernel_idx];
123
+ // Declare the lock guard outside the if block to extend its lifetime
124
+ std::unique_ptr<read_lock_guard_t > lock_guard;
125
+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
126
+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
130
127
}
131
- brgemm_desc_t *desc = tl_cache[ kernel_idx]. desc . get ();
132
- char *palette_buffer = tl_cache[kernel_idx]. palette . get ( );
133
-
134
- if (!desc-> is_tmm ) {
128
+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
129
+ " Invalid kernel handler " );
130
+ brgemm_desc_t &desc = g_cache[kernel_idx]. desc ;
131
+ if (!desc. is_tmm ) {
135
132
return ;
136
133
}
137
-
134
+ char *palette_buffer = g_cache[kernel_idx]. palette . get ();
138
135
assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
139
136
amx_tile_configure (palette_buffer);
140
137
}
@@ -150,35 +147,27 @@ void dnnl_brgemm_tilerelease() {
150
147
void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
151
148
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
152
149
int num) {
153
- auto &tl_cache = get_tl_cache ();
154
- if (kernel_idx >= (int64_t )tl_cache.size () ||
155
- tl_cache[kernel_idx].kernel == nullptr ) {
156
- read_lock_guard_t g (g_brgemm_lock);
157
- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
158
- if (kernel_idx >= (int64_t )tl_cache.size ()) {
159
- tl_cache.resize (kernel_idx + 1 );
160
- }
161
- tl_cache[kernel_idx] = g_cache[kernel_idx];
150
+ // Acquire the lock only if needed
151
+ std::unique_ptr<read_lock_guard_t > lock_guard;
152
+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
153
+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
162
154
}
163
- brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel ;
164
- brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc .get ();
165
-
155
+ assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache.size () &&
156
+ " Invalid kernel handler" );
157
+ brgemm_desc_t &desc = g_cache[kernel_idx].desc ;
158
+ brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel ;
166
159
assert (kernel && " Invalid brgemm kernel pointer" );
167
- assert (desc_ptr && " Invalid brgemm descriptor pointer" );
168
-
169
160
size_t A_offset_in_bytes =
170
- dnnl::impl::types::data_type_size (desc_ptr-> dt_a ) * A_offset;
161
+ dnnl::impl::types::data_type_size (desc. dt_a ) * A_offset;
171
162
size_t B_offset_in_bytes =
172
- dnnl::impl::types::data_type_size (desc_ptr-> dt_b ) * B_offset;
163
+ dnnl::impl::types::data_type_size (desc. dt_b ) * B_offset;
173
164
size_t C_offset_in_bytes =
174
- dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
175
-
176
- char *A_arith = (char *)A;
177
- char *B_arith = (char *)B;
178
- char *C_arith = (char *)C;
179
- brgemm_kernel_execute (kernel, num, (void *)(A_arith + A_offset_in_bytes),
180
- (void *)(B_arith + B_offset_in_bytes), nullptr ,
181
- (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
165
+ dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
166
+ char *A_arith = static_cast <char *>(A) + A_offset_in_bytes;
167
+ char *B_arith = static_cast <char *>(B) + B_offset_in_bytes;
168
+ char *C_arith = static_cast <char *>(C) + C_offset_in_bytes;
169
+ brgemm_kernel_execute (kernel, num, A_arith, B_arith, nullptr , C_arith,
170
+ scratch);
182
171
}
183
172
}
184
173
0 commit comments