@@ -16,7 +16,6 @@
 namespace torch_ipex {
 namespace tpp {
 
-static int large_cache_opt = false;
 static int use_at_vnni = false; // env2int("USE_AT_VNNI");
 static int FT_OPT_SIZE = env2int("FT_OPT_SIZE", 256);
 static int NCB_BLOCK_SIZE = env2int("NCB_BLOCK_SIZE", 64);
@@ -58,15 +57,16 @@ inline at::Tensor wt_tensor_for_first_token(at::Tensor& t) {
   if (dim < 5)
     return t;
   auto sizes = t.sizes();
-  constexpr long RBS = 2;
+  constexpr long RBS = 4;
   auto K1 = sizes[0];
   if (K1 % RBS != 0)
     return t;
   auto C1 = sizes[1];
   auto C2 = sizes[2];
   auto K2 = sizes[3];
   auto C3 = sizes[4];
-
+  if (K2 >= 32)
+    return t;
   auto t_new = t.new_empty({K1 / RBS, C1, C2, RBS * K2, C3});
   auto in = GetVLAPtr<T>(t, {RBS, C1, C2, K2 * C3});
   auto out = GetVLAPtr<T>(t_new, {C1, C2, RBS, K2 * C3});
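Note on the hunk above: the row-block size RBS used when repacking weights for the first-token (prefill) path goes from 2 to 4, and the repack is now skipped when K2 is already 32 or larger. The copy loop that follows these declarations is not part of the hunk; the sketch below is only a reconstruction inferred from the GetVLAPtr view shapes, so the actual loop structure and its parallelization may differ:

  // Sketch (assumed, not the actual code): fold RBS consecutive K1 blocks of a
  // [K1, C1, C2, K2, C3] blocked weight into [K1/RBS, C1, C2, RBS*K2, C3], so
  // the prefill GEMM works on wider K blocks.
  // 'in' views t as [K1/RBS][RBS][C1][C2][K2*C3];
  // 'out' views t_new as [K1/RBS][C1][C2][RBS][K2*C3].
  for (long k1 = 0; k1 < K1 / RBS; k1++)
    for (long c1 = 0; c1 < C1; c1++)
      for (long c2 = 0; c2 < C2; c2++)
        for (long r = 0; r < RBS; r++)
          // copy one contiguous [K2, C3] tile (needs <algorithm>)
          std::copy_n(in[k1][r][c1][c2], K2 * C3, out[k1][c1][c2][r]);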
@@ -96,6 +96,7 @@ inline void tpp_linear_bias(
   auto in_sizes = t_in.sizes();
   auto wt_sizes = t_wt_.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     if (wt_sizes[3] != 100) {
       t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
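From this hunk on, the same change repeats in every kernel entry point (tpp_linear_no_bias, tpp_linear_mul, tpp_linear_add_add, tpp_linear_gelu, tpp_fused_gate_up_proj, tpp_linear_add, tpp_linear_silu, tpp_linear_relu): the file-scope static int large_cache_opt removed in the first hunk becomes a function-local bool, recomputed on each call from the BS > FT_OPT_SIZE first-token check. A mutable file-scope flag is shared across calls and threads, so a large prefill call could leave the flag set for a later small decode call; a local keeps the decision per invocation. A minimal standalone sketch of the pattern (names and the threshold value here are illustrative, not taken from the library):

  #include <cstdint>

  constexpr int64_t FT_OPT_SIZE = 256; // hypothetical first-token threshold

  void linear_kernel(int64_t batch, int64_t seq_len) {
    // Decided per call rather than via a mutable file-scope static, so
    // interleaved prefill/decode calls cannot observe each other's setting.
    bool large_cache_opt = false;
    if (batch * seq_len > FT_OPT_SIZE) { // first token compute
      large_cache_opt = true;
    }
    // ... the flag would then steer the blocking / loop-scheme choice
    // further down in the kernel (not shown here) ...
    (void)large_cache_opt;
  }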
@@ -183,13 +184,15 @@ inline void tpp_linear_no_bias(
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
   auto wt_sizes = t_wt_.sizes();
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     if (wt_sizes[3] != 100) {
       t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
       wt_sizes = t_wt_.sizes();
     }
     large_cache_opt = true;
   }
+
   auto C = in_sizes[2];
 
   auto Nc = wt_sizes[1];
@@ -254,10 +257,12 @@ inline void tpp_linear_mul(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -348,6 +353,7 @@ inline void tpp_linear_add_add(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
@@ -444,10 +450,12 @@ inline void tpp_linear_gelu(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -546,6 +554,7 @@ inline void tpp_fused_gate_up_proj(
   auto t_wt_up_ = t_wt_up;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_gate_ = wt_tensor_for_first_token<T>(t_wt_gate_);
     t_wt_up_ = wt_tensor_for_first_token<T>(t_wt_up_);
@@ -670,10 +679,12 @@ inline void tpp_linear_add(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -761,10 +772,12 @@ inline void tpp_linear_silu(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 
@@ -851,10 +864,12 @@ inline void tpp_linear_relu(
   auto t_wt_ = t_wt;
   auto in_sizes = t_in.sizes();
   auto BS = in_sizes[0] * in_sizes[1];
+  bool large_cache_opt = false;
   if (BS > FT_OPT_SIZE) { // first token compute
     t_wt_ = wt_tensor_for_first_token<T>(t_wt_);
     large_cache_opt = true;
   }
+
   auto wt_sizes = t_wt_.sizes();
   auto C = in_sizes[2];
 