@@ -48,93 +48,86 @@ __attribute__((weak)) void print_verbose_header() {}
48
48
} // namespace dnnl
49
49
50
50
// Number of bytes allocated for each kernel's tile palette buffer.
static constexpr int PALETTE_SIZE = 64;
// Number of slots pre-allocated in g_cache. Kernel indices below this value
// are read by the tileconfig/execute entry points without taking
// g_brgemm_lock; indices at or above it are read under a shared lock.
static constexpr int DEFAULT_KERNEL_SIZE = 1024;

// Shared (reader) and exclusive (writer) guards over g_brgemm_lock.
using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
// Guards kernel_id and growth of g_cache.
static std::shared_mutex g_brgemm_lock;
55
56
56
57
struct brgemm_cache_info_t {
57
- std::shared_ptr< brgemm_desc_t > desc;
58
+ brgemm_desc_t desc;
58
59
brgemm_kernel_t *kernel;
59
- std::shared_ptr <char []> palette;
60
+ std::unique_ptr <char []> palette;
60
61
};
61
62
62
- static std::vector<brgemm_cache_info_t > g_cache;
63
+ static std::vector<brgemm_cache_info_t > g_cache (DEFAULT_KERNEL_SIZE);
64
+ static int64_t kernel_id = -1 ;
63
65
64
66
// Size in bytes of the per-thread scratch buffer handed to
// brgemm_kernel_execute.
// TODO(haixin): use syscall to determine page size?
static constexpr size_t SCRATCH_SIZE = 2 * 4096;
// Zero-initialized scratch buffer; one instance per thread.
// TODO(haixin): need to use custom thread management for scratch in the future?
static thread_local char scratch[SCRATCH_SIZE] = {0};
68
70
69
- static std::vector<brgemm_cache_info_t > &get_tl_cache () {
70
- thread_local std::vector<brgemm_cache_info_t > tl_cache;
71
- return tl_cache;
72
- }
73
-
74
71
extern " C" {
75
72
76
73
int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
77
74
int64_t LDB, int64_t LDC, int64_t stride_a,
78
75
int64_t stride_b, float beta, int64_t dtypeA,
79
76
int64_t dtypeB) {
80
- std::shared_ptr<brgemm_desc_t > desc_ptr = std::make_shared<brgemm_desc_t >();
81
- brgemm_desc_t *desc = desc_ptr.get ();
82
- brgemm_kernel_t *kernel;
83
77
auto dnnl_dtypeA = static_cast <dnnl_data_type_t >(dtypeA);
84
78
auto dnnl_dtypeB = static_cast <dnnl_data_type_t >(dtypeB);
85
79
int64_t dtypeA_size = dnnl::impl::types::data_type_size (dnnl_dtypeA);
86
80
int64_t dtypeB_size = dnnl::impl::types::data_type_size (dnnl_dtypeB);
87
81
brgemm_strides_t stride_info{stride_a * dtypeA_size, stride_b * dtypeB_size};
88
82
83
+ write_lock_guard_t g (g_brgemm_lock);
84
+ kernel_id++;
85
+
86
+ if (kernel_id >= DEFAULT_KERNEL_SIZE) {
87
+ if (kernel_id >= (int64_t )g_cache.size ()) {
88
+ g_cache.resize (kernel_id + 1 );
89
+ }
90
+ }
91
+
89
92
dnnl::impl::status_t status = brgemm_desc_init (
90
- desc, cpu_isa_t ::isa_undef, brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA ,
91
- dnnl_dtypeB, /* transA= */ false , /* transB= */ false ,
92
- brgemm_layout_t ::brgemm_row_major, 1 . 0f , beta, LDA, LDB, LDC, M, N, K ,
93
- &stride_info);
93
+ &g_cache[kernel_id]. desc , cpu_isa_t ::isa_undef,
94
+ brgemm_batch_kind_t ::brgemm_strd, dnnl_dtypeA, dnnl_dtypeB ,
95
+ /* transA= */ false , /* transB= */ false , brgemm_layout_t ::brgemm_row_major ,
96
+ 1 . 0f , beta, LDA, LDB, LDC, M, N, K, &stride_info);
94
97
assert (status == dnnl::impl::status::success &&
95
98
" Failed to initialize BRGEMM descriptor" );
96
99
97
- status = brgemm_kernel_create (&kernel, *desc);
100
+ status =
101
+ brgemm_kernel_create (&g_cache[kernel_id].kernel , g_cache[kernel_id].desc );
98
102
assert (status == dnnl::impl::status::success &&
99
103
" Failed to JIT BRGEMM kernel" );
100
104
101
105
brgemm_attr_t dnnl_attrs;
102
- brgemm_desc_set_attr (desc, dnnl_attrs);
103
-
104
- // TODO(haixin): Reuse identical palettes across kernels
105
- std::shared_ptr<char []> palette_buffer;
106
- if (desc->is_tmm ) {
107
- palette_buffer.reset (new char [PALETTE_SIZE]);
108
- dnnl::impl::status_t status =
109
- brgemm_init_tiles (*desc, palette_buffer.get ());
106
+ brgemm_desc_set_attr (&g_cache[kernel_id].desc , dnnl_attrs);
107
+
108
+ if (g_cache[kernel_id].desc .is_tmm ) {
109
+ g_cache[kernel_id].palette .reset (new char [PALETTE_SIZE]);
110
+ status = brgemm_init_tiles (g_cache[kernel_id].desc ,
111
+ g_cache[kernel_id].palette .get ());
110
112
assert (status == dnnl::impl::status::success &&
111
113
" Failed to initialize palette for BRGEMM" );
112
114
}
113
115
114
- write_lock_guard_t g (g_brgemm_lock);
115
- g_cache.push_back (brgemm_cache_info_t {desc_ptr, kernel, palette_buffer});
116
- return g_cache.size () - 1 ;
116
+ return kernel_id;
117
117
}
118
118
119
119
void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
120
- assert (kernel_idx >= 0 && " Invalid kernel handler" );
121
- auto &tl_cache = get_tl_cache ();
122
- if (kernel_idx >= (int64_t )tl_cache.size () ||
123
- tl_cache[kernel_idx].kernel == nullptr ) {
124
- read_lock_guard_t g (g_brgemm_lock);
125
- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
126
- if (kernel_idx >= (int64_t )tl_cache.size ()) {
127
- tl_cache.resize (kernel_idx + 1 );
128
- }
129
- tl_cache[kernel_idx] = g_cache[kernel_idx];
120
+ std::unique_ptr<read_lock_guard_t > lock_guard;
121
+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
122
+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
130
123
}
131
- brgemm_desc_t *desc = tl_cache[ kernel_idx]. desc . get ();
132
- char *palette_buffer = tl_cache[kernel_idx]. palette . get ( );
133
-
134
- if (!desc-> is_tmm ) {
124
+ assert (kernel_idx >= 0 && kernel_idx < ( int64_t )g_cache. size () &&
125
+ " Invalid kernel handler " );
126
+ brgemm_desc_t &desc = g_cache[kernel_idx]. desc ;
127
+ if (!desc. is_tmm ) {
135
128
return ;
136
129
}
137
-
130
+ char *palette_buffer = g_cache[kernel_idx]. palette . get ();
138
131
assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
139
132
amx_tile_configure (palette_buffer);
140
133
}
@@ -150,35 +143,26 @@ void dnnl_brgemm_tilerelease() {
150
143
void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
151
144
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
152
145
int num) {
153
- auto &tl_cache = get_tl_cache ();
154
- if (kernel_idx >= (int64_t )tl_cache.size () ||
155
- tl_cache[kernel_idx].kernel == nullptr ) {
156
- read_lock_guard_t g (g_brgemm_lock);
157
- assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
158
- if (kernel_idx >= (int64_t )tl_cache.size ()) {
159
- tl_cache.resize (kernel_idx + 1 );
160
- }
161
- tl_cache[kernel_idx] = g_cache[kernel_idx];
146
+ std::unique_ptr<read_lock_guard_t > lock_guard;
147
+ if (kernel_idx >= DEFAULT_KERNEL_SIZE) {
148
+ lock_guard = std::make_unique<read_lock_guard_t >(g_brgemm_lock);
162
149
}
163
- brgemm_kernel_t *kernel = tl_cache[kernel_idx].kernel ;
164
- brgemm_desc_t *desc_ptr = tl_cache[kernel_idx].desc .get ();
165
-
150
+ assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache.size () &&
151
+ " Invalid kernel handler" );
152
+ brgemm_desc_t &desc = g_cache[kernel_idx].desc ;
153
+ brgemm_kernel_t *kernel = g_cache[kernel_idx].kernel ;
166
154
assert (kernel && " Invalid brgemm kernel pointer" );
167
- assert (desc_ptr && " Invalid brgemm descriptor pointer" );
168
-
169
155
size_t A_offset_in_bytes =
170
- dnnl::impl::types::data_type_size (desc_ptr-> dt_a ) * A_offset;
156
+ dnnl::impl::types::data_type_size (desc. dt_a ) * A_offset;
171
157
size_t B_offset_in_bytes =
172
- dnnl::impl::types::data_type_size (desc_ptr-> dt_b ) * B_offset;
158
+ dnnl::impl::types::data_type_size (desc. dt_b ) * B_offset;
173
159
size_t C_offset_in_bytes =
174
- dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
175
-
176
- char *A_arith = (char *)A;
177
- char *B_arith = (char *)B;
178
- char *C_arith = (char *)C;
179
- brgemm_kernel_execute (kernel, num, (void *)(A_arith + A_offset_in_bytes),
180
- (void *)(B_arith + B_offset_in_bytes), nullptr ,
181
- (void *)(C_arith + C_offset_in_bytes), (void *)scratch);
160
+ dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
161
+ char *A_arith = static_cast <char *>(A) + A_offset_in_bytes;
162
+ char *B_arith = static_cast <char *>(B) + B_offset_in_bytes;
163
+ char *C_arith = static_cast <char *>(C) + C_offset_in_bytes;
164
+ brgemm_kernel_execute (kernel, num, A_arith, B_arith, nullptr , C_arith,
165
+ scratch);
182
166
}
183
167
}
184
168
0 commit comments