@@ -48,6 +48,8 @@ __attribute__((weak)) void print_verbose_header() {}
48
48
} // namespace dnnl
49
49
50
50
// Size, in bytes, of the buffer allocated to hold one tile palette (the buffer
// later handed to brgemm_init_tiles / amx_tile_configure).
static constexpr int PALETTE_SIZE = 64;
// Number of kernel slots the global cache is created with. Handles below this
// value are looked up without taking the global lock.
static constexpr int DEFAULT_KERNEL_SIZE = 1024;
// Upper bound on the number of kernels that may be dispatched.
static constexpr int MAX_KERNEL_SIZE = 2048;

// Shared (reader) and exclusive (writer) lock holders for the global
// std::shared_mutex guarding the kernel cache.
using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
@@ -56,81 +58,78 @@ static std::shared_mutex g_brgemm_lock;
56
58
struct brgemm_cache_info_t {
57
59
brgemm_desc_t desc;
58
60
brgemm_kernel_t *kernel;
59
- std::shared_ptr <char []> palette;
61
+ std::unique_ptr <char []> palette;
60
62
};
61
63
62
- static std::vector<brgemm_cache_info_t > g_cache;
64
+ static std::vector<brgemm_cache_info_t > g_cache (DEFAULT_KERNEL_SIZE);
65
+ static int64_t g_kernel_id = -1 ;
63
66
64
67
// TODO(haixin): use syscall to determine page size?
65
68
static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
66
69
// TODO(haixin): need to use custom thread management for scratch in the future?
67
70
static thread_local char scratch[SCRATCH_SIZE] = {0 };
68
71
69
- static std::unordered_map<int64_t , brgemm_cache_info_t > &get_tl_cache () {
70
- thread_local std::unordered_map<int64_t , brgemm_cache_info_t > tl_cache;
71
- return tl_cache;
72
- }
73
-
74
72
extern " C" {
75
73
76
74
int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
77
75
int64_t LDB, int64_t LDC, int64_t stride_a,
78
76
int64_t stride_b, float beta, int64_t dtypeA,
79
77
int64_t dtypeB) {
80
- brgemm_desc_t desc;
81
- brgemm_kernel_t *kernel;
82
-
83
78
auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
84
79
auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
85
80
int64_t dtypeA_size = dnnl::impl::types::data_type_size (dnnl_dtypeA);
86
81
int64_t dtypeB_size = dnnl::impl::types::data_type_size (dnnl_dtypeB);
87
82
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
88
83
84
+ write_lock_guard_t g (g_brgemm_lock);
85
+ g_kernel_id++;
86
+ assert (g_kernel_id < MAX_KERNEL_SIZE &&
87
+ " Too many brgemm kernels are created" );
88
+ if (g_kernel_id >= DEFAULT_KERNEL_SIZE) {
89
+ if (g_kernel_id >= (int64_t )g_cache.size ()) {
90
+ g_cache.resize (g_kernel_id + 1 );
91
+ }
92
+ }
93
+
89
94
dnnl::impl::status_t status = brgemm_desc_init (
90
- &desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd ,
91
- dnnl_dtypeA, dnnl_dtypeB, /* transA= */ false , /* transB= */ false ,
92
- brgemm_layout_t ::brgemm_row_major, 1 . 0f , beta, LDA, LDB, LDC, M, N, K ,
93
- &stride_info);
95
+ &g_cache[g_kernel_id]. desc , cpu_isa_t ::isa_undef,
96
+ brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB ,
97
+ /* transA= */ false , /* transB= */ false , brgemm_layout_t ::brgemm_row_major ,
98
+ 1 . 0f , beta, LDA, LDB, LDC, M, N, K, &stride_info);
94
99
assert (status == dnnl::impl::status::success &&
95
100
" Failed to initialize BRGEMM descriptor" );
96
101
97
- status = brgemm_kernel_create (&kernel, desc);
102
+ status = brgemm_kernel_create (&g_cache[g_kernel_id].kernel ,
103
+ g_cache[g_kernel_id].desc );
98
104
assert (status == dnnl::impl::status::success &&
99
105
" Failed to JIT BRGEMM kernel" );
100
106
101
107
brgemm_attr_t dnnl_attrs;
102
- brgemm_desc_set_attr (&desc, dnnl_attrs);
108
+ brgemm_desc_set_attr (&g_cache[g_kernel_id]. desc , dnnl_attrs);
103
109
104
- // TODO(haixin): Reuse identical palettes across kernels
105
- std::shared_ptr<char []> palette_buffer;
106
- if (desc.is_tmm ) {
107
- palette_buffer.reset (new char [PALETTE_SIZE]);
108
- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer.get ());
110
+ if (g_cache[g_kernel_id].desc .is_tmm ) {
111
+ g_cache[g_kernel_id].palette .reset (new char [PALETTE_SIZE]);
112
+ status = brgemm_init_tiles (g_cache[g_kernel_id].desc ,
113
+ g_cache[g_kernel_id].palette .get ());
109
114
assert (status == dnnl::impl::status::success &&
110
115
" Failed to initialize palette for BRGEMM" );
111
116
}
112
117
113
- write_lock_guard_t g (g_brgemm_lock);
114
- g_cache.push_back (brgemm_cache_info_t {desc, kernel, palette_buffer});
115
- return g_cache.size () - 1 ;
118
+ return g_kernel_id;
116
119
}
117
120
118
121
void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
119
- assert (kernel_idx >= 0 && " Invalid kernel handler" );
120
- auto &tl_cache = get_tl_cache ();
121
- auto it = tl_cache.find (kernel_idx);
122
- if (it == tl_cache.end ()) {
123
- read_lock_guard_t g (g_brgemm_lock);
124
- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
125
- it = tl_cache.insert ({kernel_idx, g_cache[kernel_idx]}).first ;
122
+ std::unique_ptr<read_lock_guard_t > lock_guard;
123
+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
124
+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
126
125
}
127
- brgemm_desc_t &desc = it-> second . desc ;
128
- char *palette_buffer = it-> second . palette . get ( );
129
-
126
+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
127
+ " Invalid kernel handler " );
128
+ brgemm_desc_t &desc = g_cache[kernel_idx]. desc ;
130
129
if (!desc.is_tmm ) {
131
130
return ;
132
131
}
133
-
132
+ char *palette_buffer = g_cache[kernel_idx]. palette . get ();
134
133
assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
135
134
amx_tile_configure (palette_buffer);
136
135
}
@@ -146,35 +145,26 @@ void dnnl_brgemm_tilerelease() {
146
145
void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
147
146
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
148
147
int num) {
149
- auto &tl_cache = get_tl_cache ();
150
- if (tl_cache.find (kernel_idx) == tl_cache.end ()) {
151
- read_lock_guard_t g (g_brgemm_lock);
152
- assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache.size () &&
153
- " Invalid kernel handler" );
154
- auto updated_cache =
155
- tl_cache.insert (std::make_pair (kernel_idx, g_cache[kernel_idx]));
156
- assert (updated_cache.second && " insert into thread local cache" );
148
+ std::unique_ptr<read_lock_guard_t > lock_guard;
149
+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
150
+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
157
151
}
158
- auto it = tl_cache. find (kernel_idx);
159
- brgemm_kernel_t *kernel = it-> second . kernel ;
160
- brgemm_desc_t *desc_ptr = &it-> second .desc ;
161
-
152
+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
153
+ " Invalid kernel handler " ) ;
154
+ brgemm_desc_t &desc = g_cache[kernel_idx] .desc ;
155
+ brgemm_kernel_t *kernel = g_cache[kernel_idx]. kernel ;
162
156
assert (kernel && " Invalid brgemm kernel pointer" );
163
- assert (desc_ptr && " Invalid brgemm descriptor pointer" );
164
-
165
157
size_t A_offset_in_bytes =
166
- dnnl::impl::types::data_type_size (desc_ptr-> dt_a ) * A_offset;
158
+ dnnl::impl::types::data_type_size (desc. dt_a ) * A_offset;
167
159
size_t B_offset_in_bytes =
168
- dnnl::impl::types::data_type_size (desc_ptr-> dt_b ) * B_offset;
160
+ dnnl::impl::types::data_type_size (desc. dt_b ) * B_offset;
169
161
size_t C_offset_in_bytes =
170
- dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
171
-
172
- char *A_arith = (char *)A;
173
- char *B_arith = (char *)B;
174
- char *C_arith = (char *)C;
175
- brgemm_kernel_execute (kernel, num, (void *)(A_arith + A_offset_in_bytes),
176
- (void *)(B_arith + B_offset_in_bytes), nullptr ,
177
- (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
162
+ dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
163
+ char *A_arith = static_cast <char *>(A) + A_offset_in_bytes;
164
+ char *B_arith = static_cast <char *>(B) + B_offset_in_bytes;
165
+ char *C_arith = static_cast <char *>(C) + C_offset_in_bytes;
166
+ brgemm_kernel_execute (kernel, num, A_arith, B_arith, nullptr , C_arith,
167
+ scratch);
178
168
}
179
169
}
180
170
0 commit comments