 #endif

 // bump if necessary
-#define LLAMA_MAX_NODES 8192
 #define LLAMA_MAX_LAYERS 512
 #define LLAMA_MAX_EXPERTS 160 // DeepSeekV2

@@ -3567,6 +3566,15 @@ namespace GGUFMeta {

 using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>;

+// TODO: update when needed or think of some clever automatic way to do this
+static size_t llama_model_max_nodes(const llama_model & /*model*/) {
+    //if (model.arch == LLM_ARCH_LLAMA && model.hparams.n_layer > ??) { // llama-3 405B
+    //    return 32768;
+    //}
+
+    return 8192;
+}
+
 struct llama_model_loader {
     int n_kv = 0;
     int n_tensors = 0;
@@ -8396,7 +8404,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_k_shift() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         GGML_ASSERT(kv_self.size == n_ctx);

@@ -8427,7 +8435,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_s_copy() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         GGML_ASSERT(kv_self.recurrent);

@@ -8450,7 +8458,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         for (uint32_t i = 0; i < ids.size(); ++i) {
             const uint32_t id = ids[i];
@@ -8691,7 +8699,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_llama() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -8834,7 +8842,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_baichuan() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -8949,7 +8957,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_xverse() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9052,7 +9060,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_falcon() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -9172,7 +9180,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_grok() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9329,7 +9337,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_dbrx() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -9455,7 +9463,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_starcoder() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -9559,7 +9567,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_refact() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9653,7 +9661,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_bert() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -9847,7 +9855,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_bloom() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -9948,7 +9956,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_mpt() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -10238,7 +10246,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_qwen() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10350,7 +10358,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_qwen2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -10462,7 +10470,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_qwen2moe() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -10608,7 +10616,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_phi2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -10729,7 +10737,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_phi3() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -10961,7 +10969,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_gpt2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -11066,7 +11074,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_codeshell() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -11177,7 +11185,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_orion() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11295,7 +11303,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_internlm2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11416,7 +11424,7 @@ struct llm_build_context {
     // https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
     // based on the original build_llama() function
     struct ggml_cgraph * build_minicpm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11560,7 +11568,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_gemma() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head_k = hparams.n_embd_head_k;

@@ -11668,7 +11676,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_gemma2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head_k = hparams.n_embd_head_k;

@@ -11803,7 +11811,7 @@ struct llm_build_context {


     struct ggml_cgraph * build_starcoder2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -11922,7 +11930,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_mamba() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t d_model = n_embd;
         const int64_t d_conv = hparams.ssm_d_conv;
@@ -12071,7 +12079,7 @@ struct llm_build_context {

     struct ggml_cgraph * build_command_r() {

-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12225,7 +12233,7 @@ struct llm_build_context {
     // * removed bias
     // * removed MoE
     struct ggml_cgraph * build_olmo() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12349,7 +12357,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_openelm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -12474,7 +12482,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_gptneox() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -12616,7 +12624,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_arctic() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12748,7 +12756,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_deepseek2() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -12976,7 +12984,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_bitnet() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13116,7 +13124,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_t5() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         // mutable variable, needed during the last layer of the computation to skip unused tokens
         int32_t n_tokens = this->n_tokens;
@@ -13433,7 +13441,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_jais() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -13525,7 +13533,7 @@ struct llm_build_context {
     }

     struct ggml_cgraph * build_chatglm() {
-        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
@@ -14870,9 +14878,9 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
     // each move requires 6*n_layer tensors (see build_defrag)
     // - source view, destination view, copy operation
     // - x2 for keys and values
-    //const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);
+    //const uint32_t max_moves = llama_model_max_nodes(model)/(6*n_layer);
     // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516
-    const uint32_t max_moves = (LLAMA_MAX_NODES - 2*n_layer)/(6*n_layer);
+    const uint32_t max_moves = (llama_model_max_nodes(lctx.model) - 2*n_layer)/(6*n_layer);

    // determine which KV cells to move where
    //
@@ -16762,8 +16770,10 @@ struct llama_context * llama_new_context_with_model(
            }
        }

+       const size_t max_nodes = llama_model_max_nodes(*model);
+
        // buffer used to store the computation graph and the tensor meta data
-       ctx->buf_compute_meta.resize(ggml_tensor_overhead()*LLAMA_MAX_NODES + ggml_graph_overhead_custom(LLAMA_MAX_NODES, false));
+       ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));

        // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
        bool pipeline_parallel =
@@ -16776,7 +16786,7 @@ struct llama_context * llama_new_context_with_model(
        // currently this is only implemented in the CUDA backend
        pipeline_parallel = false;
 #endif
-       ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), LLAMA_MAX_NODES, pipeline_parallel);
+       ctx->sched = ggml_backend_sched_new(ctx->backends.data(), backend_buft.data(), ctx->backends.size(), max_nodes, pipeline_parallel);

        if (pipeline_parallel) {
            LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(ctx->sched));
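For context: the new `llama_model_max_nodes()` helper currently ignores its `model` argument and always returns the previous fixed budget of 8192 graph nodes, while the commented-out branch hints that very deep models (the llama-3 405B case) will eventually need a larger value such as 32768. Below is a minimal, self-contained sketch of what that future branch might look like; the stand-in types, the function name, and the 100-layer threshold are assumptions for illustration and are not part of this patch.

```cpp
#include <cstddef>

// Hypothetical stand-ins for the real llama_model/llama_hparams types,
// used only so this sketch compiles on its own.
struct hparams_sketch { unsigned n_layer = 32; };
struct model_sketch   { hparams_sketch hparams; };

// Sketch of a per-model node budget along the lines of the TODO in the diff:
// very deep models (e.g. Llama 3.1 405B with 126 layers) would get a larger
// graph, everything else keeps the existing 8192-node default.
static size_t model_max_nodes_sketch(const model_sketch & model) {
    if (model.hparams.n_layer > 100) { // threshold is an assumption, not from the patch
        return 32768;
    }
    return 8192;
}
```

The same budget also bounds KV-cache defragmentation: with the default 8192 nodes and, say, a 32-layer model, the formula in `llama_kv_cache_defrag_internal()` gives `max_moves = (8192 - 2*32)/(6*32)` = 42 moves per defrag pass.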