
Commit 61c2ed4

Implement customizable RoPE
The original RoPE has pre-defined parameters

    theta_i = 10000^(-2(i-1)/d), for i in [1, 2, ..., d/2]

Our customizable RoPE, ggml_rope_custom_inplace, uses

    theta_i = scale * base^(-2(i-1)/d), for i in [1, 2, ..., d/2]

where the defaults match the original: scale = 1.0, base = 10000.

The new command line arguments --rope-freq-base and --rope-freq-scale set the two new RoPE parameters. Recent research shows that changing these two parameters extends the context limit with minimal loss:

1. Extending Context to 8K
   kaiokendev
   https://kaiokendev.github.io/til#extending-context-to-8k

2. Extending Context Window of Large Language Models via Positional Interpolation
   Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
   https://arxiv.org/abs/2306.15595

3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation
   bloc97 (https://www.reddit.com/user/bloc97)
   https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

For the bold, try adding the following command line parameters to your favorite model:

    -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5
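To make the schedule concrete, here is a minimal standalone sketch (for illustration only, not part of this commit) that evaluates theta_i for the suggested settings; the constants and variable names below are illustrative:

// Sketch of the customizable RoPE frequency schedule. With scale = 1.0
// and base = 10000 it reproduces the original pre-defined parameters.
#include <cmath>
#include <cstdio>

int main() {
    const int   d     = 128;      // rotated head dimension
    const float base  = 80000.0f; // --rope-freq-base
    const float scale = 0.5f;     // --rope-freq-scale
    const int   p     = 2048;     // example token position

    for (int i = 1; i <= d/2; i += 16) {
        // theta_i = scale * base^(-2(i-1)/d); the rotation angle applied
        // at position p for dimension pair i is p * theta_i
        const float theta_i = scale * powf(base, -2.0f*(i - 1)/d);
        printf("i = %3d  theta_i = %-12g  angle(p=%d) = %g\n", i, theta_i, p, p*theta_i);
    }
    return 0;
}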
1 parent b8c8dda commit 61c2ed4

File tree

10 files changed: +127 additions, -28 deletions


examples/common.cpp

Lines changed: 16 additions & 0 deletions
@@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_base = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -469,6 +481,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -549,6 +563,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(
     lparams.use_mlock   = params.use_mlock;
     lparams.logits_all  = params.perplexity;
     lparams.embedding   = params.embedding;
+    lparams.rope_freq_base  = params.rope_freq_base;
+    lparams.rope_freq_scale = params.rope_freq_scale;

     llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
     if (model == NULL) {

examples/common.h

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ struct gpt_params {
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     bool    low_vram                        = 0;    // if true, reduce VRAM usage at the cost of performance
+    float   rope_freq_base                  = 10000.0f; // RoPE base frequency
+    float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor

     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens

examples/main/main.cpp

Lines changed: 10 additions & 2 deletions
@@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
         return 0;
     }

+    if (params.rope_freq_base != 10000.0) {
+        fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 1.0) {
+        fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    }
+
     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
+                "you are on your own\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;

examples/server/server.cpp

Lines changed: 14 additions & 0 deletions
@@ -453,6 +453,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
     fprintf(stderr, "  -v, --verbose         verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -533,6 +535,18 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             break;
         }
         params.n_ctx = std::stoi(argv[i]);
+    } else if (arg == "--rope-freq-base") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.rope_freq_base = std::stof(argv[i]);
+    } else if (arg == "--rope-freq-scale") {
+        if (++i >= argc) {
+            invalid_param = true;
+            break;
+        }
+        params.rope_freq_scale = std::stof(argv[i]);
     } else if (arg == "--memory-f32" || arg == "--memory_f32") {
         params.memory_f16 = false;
     } else if (arg == "--threads" || arg == "-t") {

ggml-metal.m

Lines changed: 6 additions & 0 deletions
@@ -872,6 +872,10 @@ void ggml_metal_graph_compute(

                             const int n_past = ((int32_t *)(src1->data))[0];

+                            float freq_base, freq_scale;
+                            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+                            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
                             [encoder setComputePipelineState:ctx->pipeline_rope];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -894,6 +898,8 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
                             [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
                             [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+                            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+                            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;

ggml-metal.metal

Lines changed: 4 additions & 2 deletions
@@ -615,17 +615,19 @@ kernel void kernel_rope(
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
+       constant     float & freq_base,
+       constant     float & freq_scale,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
-   const float theta_scale = pow(10000.0, -2.0f/n_dims);
+   const float theta_scale = pow(freq_base, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

-   float theta = (float)p;
+   float theta = freq_scale * (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
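The kernel advances theta multiplicatively (theta *= theta_scale once per rotated dimension pair) instead of calling pow per pair. As a sanity check, here is a CPU-side sketch (for illustration only, not from this commit) confirming that the recurrence matches the closed form freq_scale * p * freq_base^(-2(i-1)/n_dims):

// Reference check (illustration only): the kernel's incremental update
// theta *= theta_scale reproduces the closed-form angle for each pair i.
#include <cmath>
#include <cassert>

int main() {
    const int   n_dims     = 128;
    const float freq_base  = 10000.0f;
    const float freq_scale = 1.0f;
    const int   p          = 42; // example token position

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    float theta = freq_scale * (float)p;
    for (int i = 1; i <= n_dims/2; ++i) {
        const float closed_form = freq_scale * p * powf(freq_base, -2.0f*(i - 1)/n_dims);
        assert(fabsf(theta - closed_form) < 1e-2f*fabsf(closed_form) + 1e-3f);
        theta *= theta_scale; // advance to the next pair of dimensions
    }
    return 0;
}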

ggml.c

Lines changed: 31 additions & 9 deletions
@@ -6778,6 +6778,8 @@ struct ggml_tensor * ggml_rope_impl(
        int                   n_past,
        int                   n_dims,
        int                   mode,
+       float                 freq_base,
+       float                 freq_scale,
        int                   n_ctx,
        bool                  inplace) {
    GGML_ASSERT(n_past >= 0);
@@ -6791,12 +6793,14 @@ struct ggml_tensor * ggml_rope_impl(

    ggml_scratch_save(ctx);

-   struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+   struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);

    ((int32_t *) b->data)[0] = n_past;
    ((int32_t *) b->data)[1] = n_dims;
    ((int32_t *) b->data)[2] = mode;
    ((int32_t *) b->data)[3] = n_ctx;
+   memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
+   memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));

    ggml_scratch_load(ctx);

@@ -6815,7 +6819,7 @@ struct ggml_tensor * ggml_rope(
        int                   n_dims,
        int                   mode,
        int                   n_ctx) {
-   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
    }

    struct ggml_tensor * ggml_rope_inplace(
@@ -6825,7 +6829,19 @@ struct ggml_tensor * ggml_rope_inplace(
        int                   n_dims,
        int                   mode,
        int                   n_ctx) {
-   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+   return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+   }
+
+   struct ggml_tensor * ggml_rope_custom_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a,
+           int                   n_past,
+           int                   n_dims,
+           int                   mode,
+           float                 freq_base,
+           float                 freq_scale,
+           int                   n_ctx) {
+       return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
    }

    // ggml_rope_back
@@ -12444,7 +12460,7 @@ static void ggml_compute_forward_rope_f32(
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-   GGML_ASSERT(ggml_nelements(src1) == 4);
+   GGML_ASSERT(ggml_nelements(src1) == 6);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
@@ -12454,6 +12470,9 @@ static void ggml_compute_forward_rope_f32(
    const int n_dims = ((int32_t *) src1->data)[1];
    const int mode   = ((int32_t *) src1->data)[2];
    const int n_ctx  = ((int32_t *) src1->data)[3];
+   float freq_base, freq_scale;
+   memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+   memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

    assert(n_past >= 0);

@@ -12495,7 +12514,7 @@ static void ggml_compute_forward_rope_f32(
    // row index used to determine which thread to use
    int ir = 0;

-   const float theta_scale = powf(10000.0, -2.0f/n_dims);
+   const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;
@@ -12507,7 +12526,7 @@ static void ggml_compute_forward_rope_f32(
            if (ir++ < ir0) continue;
            if (ir   > ir1) break;

-           float theta = (float)p;
+           float theta = freq_scale * (float)p;

            if (is_glm) {
                theta = MIN(p, n_ctx - 2);
@@ -12584,7 +12603,7 @@ static void ggml_compute_forward_rope_f16(
        const struct ggml_tensor * src1,
        struct ggml_tensor * dst) {
    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-   GGML_ASSERT(ggml_nelements(src1) == 4);
+   GGML_ASSERT(ggml_nelements(src1) == 6);

    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
        return;
@@ -12594,6 +12613,9 @@ static void ggml_compute_forward_rope_f16(
    const int n_dims = ((int32_t *) src1->data)[1];
    const int mode   = ((int32_t *) src1->data)[2];
    const int n_ctx  = ((int32_t *) src1->data)[3];
+   float freq_base, freq_scale;
+   memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+   memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

    assert(n_past >= 0);

@@ -12635,7 +12657,7 @@ static void ggml_compute_forward_rope_f16(
    // row index used to determine which thread to use
    int ir = 0;

-   const float theta_scale = powf(10000.0, -2.0f/n_dims);
+   const float theta_scale = powf(freq_base, -2.0f/n_dims);

    const bool is_neox = mode & 2;
    const bool is_glm  = mode & 4;
@@ -12647,7 +12669,7 @@ static void ggml_compute_forward_rope_f16(
            if (ir++ < ir0) continue;
            if (ir   > ir1) break;

-           float theta = (float)p;
+           float theta = freq_scale * (float)p;

            if (is_glm) {
                theta = MIN(p, n_ctx - 2);
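Note the pack/unpack convention above: the two floats ride in elements 4 and 5 of the GGML_TYPE_I32 parameter tensor, and both sides use memcpy for the type punning (a plain pointer cast would violate strict aliasing). A minimal standalone sketch of the pattern, with hypothetical names and assuming sizeof(float) == sizeof(int32_t):

// Illustration of the pack/unpack pattern: storing floats in spare slots
// of an int32 buffer via memcpy (well-defined type punning).
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    int32_t buf[6] = { 0 };          // stands in for the I32 "b" tensor
    const float freq_base  = 80000.0f;
    const float freq_scale = 0.5f;

    // pack (as in ggml_rope_impl)
    memcpy(buf + 4, &freq_base,  sizeof(float));
    memcpy(buf + 5, &freq_scale, sizeof(float));

    // unpack (as in ggml_compute_forward_rope_f32/f16)
    float base, scale;
    memcpy(&base,  buf + 4, sizeof(float));
    memcpy(&scale, buf + 5, sizeof(float));

    assert(base == freq_base && scale == freq_scale); // exact round trip
    return 0;
}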

ggml.h

Lines changed: 11 additions & 0 deletions
@@ -1055,6 +1055,17 @@ extern "C" {
            int                   mode,
            int                   n_ctx);

+   // custom RoPE, in-place, returns view(a)
+   GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a,
+           int                   n_past,
+           int                   n_dims,
+           int                   mode,
+           float                 freq_base,
+           float                 freq_scale,
+           int                   n_ctx);
+
    // rotary position embedding backward, i.e compute dx from dy
    // a - dy
    GGML_API struct ggml_tensor * ggml_rope_back(
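For orientation, here is a hedged sketch of how a caller might use the new API when building a graph. The function and tensor names are hypothetical and not from this commit; the diff above only shows that the values flow in via lparams.rope_freq_base and lparams.rope_freq_scale:

// Hypothetical call site (illustration only; requires ggml.h): apply
// customized RoPE to a query tensor while building a ggml graph.
static struct ggml_tensor * apply_custom_rope(
        struct ggml_context * ctx,
        struct ggml_tensor  * Qcur,   // hypothetical query tensor
        int n_past, int n_rot,
        float freq_base, float freq_scale) {
    // mode 0 selects the original LLaMA-style rotation; n_ctx is only
    // consulted by the GLM variant (mode & 4), so 0 is passed here
    return ggml_rope_custom_inplace(ctx, Qcur, n_past, n_rot, 0,
                                    freq_base, freq_scale, 0);
}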
