
Commit ade4538

fix tpp linear loop of first/next kernel (#2561)

1 parent 901b377

1 file changed: +18 -3 lines changed

csrc/cpu/tpp/kernels/TPPGEMMKrnl.h
Lines changed: 18 additions & 3 deletions
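The recurring `+  bool large_cache_opt = false;` lines in the diff below replace the file-scope static removed in the first hunk: previously a first-token (large-batch) call set the shared flag to true and nothing in these kernels set it back, so later next-token calls inherited it; now each call starts from false and decides for itself. A minimal standalone sketch of that before/after (the run_linear_* functions are illustrative, not the real kernels):

#include <cstdint>

constexpr int64_t FT_OPT_SIZE = 256;  // same default as env2int("FT_OPT_SIZE", 256)

// Before: one flag shared by every call in the translation unit.
static bool large_cache_opt_shared = false;

void run_linear_before(int64_t BS) {
  if (BS > FT_OPT_SIZE)              // first-token path
    large_cache_opt_shared = true;   // never reset: leaks into next-token calls
  // ... choose blocking / loop order based on large_cache_opt_shared ...
}

// After: each call computes its own flag, as this commit does.
void run_linear_after(int64_t BS) {
  bool large_cache_opt = false;      // fresh for every call
  if (BS > FT_OPT_SIZE)              // first-token path
    large_cache_opt = true;
  // ... choose blocking / loop order based on large_cache_opt ...
}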
@@ -16,7 +16,6 @@
 namespace torch_ipex {
 namespace tpp {
 
-static int large_cache_opt = false;
 static int use_at_vnni = false; // env2int("USE_AT_VNNI");
 static int FT_OPT_SIZE = env2int("FT_OPT_SIZE", 256);
 static int NCB_BLOCK_SIZE = env2int("NCB_BLOCK_SIZE", 64);
@@ -58,15 +57,16 @@ inline at::Tensor wt_tensor_for_first_token(at::Tensor& t) {
   if (dim < 5)
     return t;
   auto sizes = t.sizes();
-  constexpr long RBS = 2;
+  constexpr long RBS = 4;
   auto K1 = sizes[0];
   if (K1 % RBS != 0)
     return t;
   auto C1 = sizes[1];
   auto C2 = sizes[2];
   auto K2 = sizes[3];
   auto C3 = sizes[4];
-
+  if (K2 >= 32)
+    return t;
   auto t_new = t.new_empty({K1 / RBS, C1, C2, RBS * K2, C3});
   auto in = GetVLAPtr<T>(t, {RBS, C1, C2, K2 * C3});
   auto out = GetVLAPtr<T>(t_new, {C1, C2, RBS, K2 * C3});
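With the hunk above, wt_tensor_for_first_token folds RBS = 4 consecutive blocks of the leading dimension (rather than 2) into a single block whose fourth dimension becomes RBS * K2, and returns the tensor unchanged when K2 >= 32. A hypothetical flat-buffer version of the same index mapping, following the GetVLAPtr layouts in the hunk (this helper is illustrative only, not part of the file):

#include <algorithm>
#include <cstddef>
#include <vector>

// Regroup a weight blocked as [K1][C1][C2][K2*C3] into
// [K1/RBS][C1][C2][RBS][K2*C3]; callers would skip this when
// K1 % RBS != 0 or K2 >= 32, as the kernel now does.
std::vector<float> reblock_for_first_token(const std::vector<float>& in,
                                           size_t K1, size_t C1, size_t C2,
                                           size_t K2, size_t C3,
                                           size_t RBS = 4) {
  std::vector<float> out(in.size());
  const size_t chunk = K2 * C3;  // innermost contiguous run copied as a whole
  for (size_t k1 = 0; k1 < K1 / RBS; ++k1)
    for (size_t r = 0; r < RBS; ++r)
      for (size_t c1 = 0; c1 < C1; ++c1)
        for (size_t c2 = 0; c2 < C2; ++c2) {
          const float* src = &in[(((k1 * RBS + r) * C1 + c1) * C2 + c2) * chunk];
          float* dst = &out[(((k1 * C1 + c1) * C2 + c2) * RBS + r) * chunk];
          std::copy(src, src + chunk, dst);
        }
  return out;
}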
@@ -96,6 +96,7 @@ inline void tpp_linear_bias(
   auto in_sizes = t_in.sizes();
   auto wt_sizes = t_wt_.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     if (wt_sizes[3] != 100) {
       t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
@@ -183,13 +184,15 @@ inline void tpp_linear_no_bias(
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
   auto wt_sizes = t_wt_.sizes();
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     if (wt_sizes[3] != 100) {
       t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
       wt_sizes = t_wt_.sizes();
     }
     large_cache_opt = true;
   }
+
   auto C = in_sizes[2];
 
   auto Nc = wt_sizes[1];
@@ -254,10 +257,12 @@ inline void tpp_linear_mul(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -348,6 +353,7 @@ inline void tpp_linear_add_add(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
@@ -444,10 +450,12 @@ inline void tpp_linear_gelu(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -546,6 +554,7 @@ inline void tpp_fused_gate_up_proj(
   auto t_wt_up_ = t_wt_up;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_gate_ = wt_tensor_for_first_token<T>(t_wt_gate_);
     t_wt_up_ = wt_tensor_for_first_token<T>(t_wt_up_);
@@ -670,10 +679,12 @@ inline void tpp_linear_add(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -761,10 +772,12 @@ inline void tpp_linear_silu(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -851,10 +864,12 @@ inline void tpp_linear_relu(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
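For a concrete feel of the first/next-token split that every kernel above branches on, a small illustrative program using the default FT_OPT_SIZE of 256 (the input shapes are made up; BS is in_sizes[0] * in_sizes[1], as in the kernels):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t FT_OPT_SIZE = 256;  // default from env2int("FT_OPT_SIZE", 256)

  // Prefill call: batch 1, 1024-token prompt -> BS = 1024 > 256, so the kernel
  // reblocks the weights and enables large_cache_opt for this call only.
  const int64_t BS_prefill = 1 * 1024;

  // Decode call: batch 1, one new token -> BS = 1 <= 256, so the kernel keeps
  // the original weight layout and large_cache_opt stays false.
  const int64_t BS_decode = 1 * 1;

  std::cout << std::boolalpha
            << "prefill takes first-token path: " << (BS_prefill > FT_OPT_SIZE) << "\n"
            << "decode takes first-token path:  " << (BS_decode > FT_OPT_SIZE) << "\n";
  return 0;
}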