@@ -53,15 +53,24 @@ using read_lock_guard_t = std::shared_lock<std::shared_mutex>;
53
53
using write_lock_guard_t = std::unique_lock<std::shared_mutex>;
54
54
static std::shared_mutex g_brgemm_lock;
55
55
56
- static std::vector<brgemm_desc_t > g_brgemm_desc_list;
57
- static std::vector<brgemm_kernel_t *> g_brgemm_kernel_list;
58
- static std::vector<std::unique_ptr<char []>> g_brgemm_palette;
56
+ struct brgemm_cache_info_t {
57
+ brgemm_desc_t desc;
58
+ brgemm_kernel_t *kernel;
59
+ std::shared_ptr<char []> palette;
60
+ };
61
+
62
+ static std::vector<brgemm_cache_info_t > g_cache;
59
63
60
64
// TODO(haixin): use syscall to determine page size?
61
65
static constexpr size_t SCRATCH_SIZE = 2 * 4096 ;
62
66
// TODO(haixin): need to use custom thread management for scratch in the future?
63
67
static thread_local char scratch[SCRATCH_SIZE] = {0 };
64
68
69
+ static std::unordered_map<int64_t , brgemm_cache_info_t > &get_tl_cache () {
70
+ thread_local std::unordered_map<int64_t , brgemm_cache_info_t > tl_cache;
71
+ return tl_cache;
72
+ }
73
+
65
74
extern " C" {
66
75
67
76
int64_t dnnl_brgemm_dispatch (int64_t M, int64_t N, int64_t K, int64_t LDA,
@@ -93,33 +102,33 @@ int64_t dnnl_brgemm_dispatch(int64_t M, int64_t N, int64_t K, int64_t LDA,
93
102
brgemm_desc_set_attr (&desc, dnnl_attrs);
94
103
95
104
// TODO(haixin): Reuse identical palettes across kernels
96
- char * palette_buffer = nullptr ;
105
+ std::shared_ptr< char []> palette_buffer;
97
106
if (desc.is_tmm ) {
98
- palette_buffer = new char [PALETTE_SIZE];
99
- dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer);
107
+ palette_buffer. reset ( new char [PALETTE_SIZE]) ;
108
+ dnnl::impl::status_t status = brgemm_init_tiles (desc, palette_buffer. get () );
100
109
assert (status == dnnl::impl::status::success &&
101
110
" Failed to initialize palette for BRGEMM" );
102
111
}
103
112
104
113
write_lock_guard_t g (g_brgemm_lock);
105
- g_brgemm_desc_list.push_back (desc);
106
- g_brgemm_kernel_list.push_back (kernel);
107
- g_brgemm_palette.emplace_back (palette_buffer);
108
-
109
- return g_brgemm_desc_list.size () - 1 ;
114
+ g_cache.push_back (brgemm_cache_info_t {desc, kernel, palette_buffer});
115
+ return g_cache.size () - 1 ;
110
116
}
111
117
112
118
void dnnl_brgemm_tileconfig (int64_t kernel_idx) {
113
- char *palette_buffer = nullptr ;
114
- {
119
+ assert (kernel_idx >= 0 && " Invalid kernel handler" );
120
+ auto &tl_cache = get_tl_cache ();
121
+ auto it = tl_cache.find (kernel_idx);
122
+ if (it == tl_cache.end ()) {
115
123
read_lock_guard_t g (g_brgemm_lock);
116
- assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_brgemm_desc_list.size () &&
117
- " Invalid kernel handler" );
118
- brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
119
- if (!desc.is_tmm ) {
120
- return ;
121
- }
122
- palette_buffer = g_brgemm_palette[kernel_idx].get ();
124
+ assert (kernel_idx < (int64_t )g_cache.size () && " Invalid kernel handler" );
125
+ it = tl_cache.insert ({kernel_idx, g_cache[kernel_idx]}).first ;
126
+ }
127
+ brgemm_desc_t &desc = it->second .desc ;
128
+ char *palette_buffer = it->second .palette .get ();
129
+
130
+ if (!desc.is_tmm ) {
131
+ return ;
123
132
}
124
133
125
134
assert (palette_buffer != nullptr && " Invalid palette for BRGEMM kernel" );
@@ -137,24 +146,29 @@ void dnnl_brgemm_tilerelease() {
137
146
void dnnl_brgemm_execute (int64_t kernel_idx, void *A, uint64_t A_offset,
138
147
void *B, uint64_t B_offset, void *C, uint64_t C_offset,
139
148
int num) {
140
- brgemm_kernel_t *kernel = nullptr ;
141
- size_t A_offset_in_bytes;
142
- size_t B_offset_in_bytes;
143
- size_t C_offset_in_bytes;
144
- {
149
+ auto &tl_cache = get_tl_cache ();
150
+ if (tl_cache.find (kernel_idx) == tl_cache.end ()) {
145
151
read_lock_guard_t g (g_brgemm_lock);
146
- assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_brgemm_desc_list .size () &&
152
+ assert (kernel_idx >= 0 && kernel_idx < (int64_t )g_cache .size () &&
147
153
" Invalid kernel handler" );
148
-
149
- brgemm_desc_t &desc = g_brgemm_desc_list[kernel_idx];
150
- kernel = g_brgemm_kernel_list[kernel_idx];
151
-
152
- A_offset_in_bytes = dnnl::impl::types::data_type_size (desc.dt_a ) * A_offset;
153
- B_offset_in_bytes = dnnl::impl::types::data_type_size (desc.dt_b ) * B_offset;
154
- C_offset_in_bytes = dnnl::impl::types::data_type_size (desc.dt_c ) * C_offset;
154
+ auto updated_cache =
155
+ tl_cache.insert (std::make_pair (kernel_idx, g_cache[kernel_idx]));
156
+ assert (updated_cache.second && " insert into thread local cache" );
155
157
}
158
+ auto it = tl_cache.find (kernel_idx);
159
+ brgemm_kernel_t *kernel = it->second .kernel ;
160
+ brgemm_desc_t *desc_ptr = &it->second .desc ;
156
161
157
162
assert (kernel && " Invalid brgemm kernel pointer" );
163
+ assert (desc_ptr && " Invalid brgemm descriptor pointer" );
164
+
165
+ size_t A_offset_in_bytes =
166
+ dnnl::impl::types::data_type_size (desc_ptr->dt_a ) * A_offset;
167
+ size_t B_offset_in_bytes =
168
+ dnnl::impl::types::data_type_size (desc_ptr->dt_b ) * B_offset;
169
+ size_t C_offset_in_bytes =
170
+ dnnl::impl::types::data_type_size (desc_ptr->dt_c ) * C_offset;
171
+
158
172
char *A_arith = (char *)A;
159
173
char *B_arith = (char *)B;
160
174
char *C_arith = (char *)C;
0 commit comments