
Commit 94bbc4e

Implement customizable RoPE
The original RoPE has pre-defined parameters

    theta_i = 10000^(-2(i-1)/d), for i in [1, 2, ..., d/2]

Our customizable RoPE, ggml_rope_custom_inplace, uses

    theta_i = scale * base^(-2(i-1)/d), for i in [1, 2, ..., d/2]

with defaults that match the original:

    scale = 1.0
    base  = 10000

The new command line arguments --rope-freq-base and --rope-freq-scale set the two new RoPE parameters. Recent research shows that changing these two parameters extends the context limit with minimal loss:

1. Extending Context to 8K, by kaiokendev
   https://kaiokendev.github.io/til#extending-context-to-8k

2. Extending Context Window of Large Language Models via Positional Interpolation, by Shouyuan Chen, Sherman Wong, Liangjian Chen, Yuandong Tian
   https://arxiv.org/abs/2306.15595

3. NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation, by u/bloc97
   https://www.reddit.com/user/bloc97
   https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/

For the bold, try adding the following command line parameters to your favorite model:

    -c 16384 --rope-freq-base 80000 --rope-freq-scale 0.5
1 parent d7d2e6a commit 94bbc4e
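In the implementation below, scale enters linearly (theta = freq_scale * p), which is the positional-interpolation idea of references 1 and 2, while raising base spreads the per-dimension frequencies, the NTK-aware idea of reference 3. As a quick illustration, here is a minimal standalone C sketch (not part of the commit); `d`, `base`, and `scale` are illustrative local names for the quantities in the formula above, not ggml identifiers:

```c
#include <math.h>
#include <stdio.h>

// Minimal sketch of the customizable RoPE frequencies:
//   theta_i = scale * base^(-2(i-1)/d), for i in [1, d/2]
// `d`, `base`, and `scale` are illustrative names for the quantities
// in the commit message, not ggml identifiers.
int main(void) {
    const int   d     = 128;      // rotary dimensions per head
    const float base  = 80000.0f; // --rope-freq-base (default 10000)
    const float scale = 0.5f;     // --rope-freq-scale (default 1.0)

    for (int i = 1; i <= d/2; i += 8) {
        const float theta_i = scale * powf(base, -2.0f*(i - 1)/d);
        printf("i = %2d  theta_i = %g\n", i, theta_i);
    }
    return 0;
}
```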

File tree

10 files changed: +127 −28 lines

examples/common.cpp

Lines changed: 16 additions & 0 deletions
```diff
@@ -168,6 +168,18 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_base = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -469,6 +481,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
     fprintf(stderr, "                        or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stderr, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -549,6 +563,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     lparams.use_mlock  = params.use_mlock;
     lparams.logits_all = params.perplexity;
     lparams.embedding  = params.embedding;
+    lparams.rope_freq_base  = params.rope_freq_base;
+    lparams.rope_freq_scale = params.rope_freq_scale;

     llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams);
     if (model == NULL) {
```

examples/common.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -32,6 +32,8 @@ struct gpt_params {
     int32_t main_gpu                        = 0;    // the GPU that is used for scratch and small tensors
     float   tensor_split[LLAMA_MAX_DEVICES] = {0};  // how split tensors should be distributed across GPUs
     int32_t n_probs                         = 0;    // if greater than 0, output the probabilities of top n_probs tokens.
+    float   rope_freq_base                  = 10000.0f; // RoPE base frequency
+    float   rope_freq_scale                 = 1.0f;     // RoPE frequency scaling factor

     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
```

examples/main/main.cpp

Lines changed: 10 additions & 2 deletions
```diff
@@ -84,9 +84,17 @@ int main(int argc, char ** argv) {
         return 0;
     }

+    if (params.rope_freq_base != 10000.0) {
+        fprintf(stderr, "%s: warning: changing RoPE frequency base to %g (default 10000.0)\n", __func__, params.rope_freq_base);
+    }
+
+    if (params.rope_freq_scale != 1.0) {
+        fprintf(stderr, "%s: warning: scaling RoPE frequency by %g (default 1.0)\n", __func__, params.rope_freq_scale);
+    }
+
     if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
+        fprintf(stderr, "%s: warning: base model only supports context sizes no greater than 2048 tokens (%d specified);"
+                "you are on your own\n", __func__, params.n_ctx);
     } else if (params.n_ctx < 8) {
         fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
```

examples/server/server.cpp

Lines changed: 14 additions & 0 deletions
```diff
@@ -513,6 +513,8 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
     fprintf(stderr, "  -v, --verbose         verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stderr, "  -c N, --ctx-size N    size of the prompt context (default: %d)\n", params.n_ctx);
+    fprintf(stderr, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stderr, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stderr, "  -b N, --batch-size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
     fprintf(stderr, "                        not recommended: doubles context memory required and no measurable increase in quality\n");
@@ -593,6 +595,18 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
                 break;
             }
             params.n_ctx = std::stoi(argv[i]);
+        } else if (arg == "--rope-freq-base") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_base = std::stof(argv[i]);
+        } else if (arg == "--rope-freq-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = std::stof(argv[i]);
         } else if (arg == "--memory-f32" || arg == "--memory_f32") {
             params.memory_f16 = false;
         } else if (arg == "--threads" || arg == "-t") {
```

ggml-metal.m

Lines changed: 6 additions & 0 deletions
```diff
@@ -874,6 +874,10 @@ void ggml_metal_graph_compute(

            const int n_past = ((int32_t *)(src1->data))[0];

+            float freq_base, freq_scale;
+            memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+            memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
            [encoder setComputePipelineState:ctx->pipeline_rope];
            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -896,6 +900,8 @@ void ggml_metal_graph_compute(
            [encoder setBytes:&n_past length:sizeof( int) atIndex:18];
            [encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
            [encoder setBytes:&mode   length:sizeof( int) atIndex:20];
+            [encoder setBytes:&freq_base  length:sizeof(float) atIndex:21];
+            [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22];

            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
        } break;
```

ggml-metal.metal

Lines changed: 4 additions & 2 deletions
```diff
@@ -615,17 +615,19 @@ kernel void kernel_rope(
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
+        constant     float & freq_base,
+        constant     float & freq_scale,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
-    const float theta_scale = pow(10000.0, -2.0f/n_dims);
+    const float theta_scale = pow(freq_base, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

-    float theta = (float)p;
+    float theta = freq_scale * (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
```
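Note how the two parameters enter the kernel: freq_scale is applied once to the position p, while freq_base is folded in incrementally by multiplying theta by theta_scale for each dimension pair. A small plain-C sketch (assumed names, not Metal source) checking that this recurrence matches the closed form freq_scale * p * freq_base^(-i0/n_dims):

```c
#include <math.h>
#include <stdio.h>

// Sketch: the kernel updates theta incrementally (theta *= theta_scale
// once per dimension pair), which is equivalent to the closed form
// freq_scale * p * freq_base^(-i0/n_dims) for even i0. Names mirror the
// kernel above, but this is plain C, not Metal.
int main(void) {
    const int   n_dims     = 128;
    const float freq_base  = 80000.0f;
    const float freq_scale = 0.5f;
    const int   p          = 42;   // token position

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    float theta = freq_scale * (float)p;
    for (int i0 = 0; i0 < n_dims; i0 += 2) {
        const float closed = freq_scale * p * powf(freq_base, -(float)i0/n_dims);
        if (i0 % 32 == 0) {
            printf("i0 = %3d  incremental = %g  closed = %g\n", i0, theta, closed);
        }
        theta *= theta_scale; // advance to the next dimension pair
    }
    return 0;
}
```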

ggml.c

Lines changed: 31 additions & 9 deletions
```diff
@@ -6815,6 +6815,8 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        float                 freq_base,
+        float                 freq_scale,
         int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
@@ -6828,12 +6830,14 @@ struct ggml_tensor * ggml_rope_impl(

     ggml_scratch_save(ctx);

-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 6);

     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
     ((int32_t *) b->data)[3] = n_ctx;
+    memcpy((int32_t *) b->data + 4, &freq_base,  sizeof(float));
+    memcpy((int32_t *) b->data + 5, &freq_scale, sizeof(float));

     ggml_scratch_load(ctx);

@@ -6852,7 +6856,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
 }

 struct ggml_tensor * ggml_rope_inplace(
@@ -6862,7 +6866,19 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+}
+
+struct ggml_tensor * ggml_rope_custom_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        int                   n_past,
+        int                   n_dims,
+        int                   mode,
+        float                 freq_base,
+        float                 freq_scale,
+        int                   n_ctx) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
 }

 // ggml_rope_back
@@ -12481,7 +12497,7 @@ static void ggml_compute_forward_rope_f32(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 4);
+    GGML_ASSERT(ggml_nelements(src1) == 6);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12491,6 +12507,9 @@ static void ggml_compute_forward_rope_f32(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
+    float freq_base, freq_scale;
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

     assert(n_past >= 0);

@@ -12532,7 +12551,7 @@ static void ggml_compute_forward_rope_f32(
     // row index used to determine which thread to use
     int ir = 0;

-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);

     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12544,7 +12563,7 @@ static void ggml_compute_forward_rope_f32(
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;

-            float theta = (float)p;
+            float theta = freq_scale * (float)p;

             if (is_glm) {
                 theta = MIN(p, n_ctx - 2);
@@ -12621,7 +12640,7 @@ static void ggml_compute_forward_rope_f16(
         const struct ggml_tensor * src1,
         struct ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(ggml_nelements(src1) == 4);
+    GGML_ASSERT(ggml_nelements(src1) == 6);

     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
@@ -12631,6 +12650,9 @@ static void ggml_compute_forward_rope_f16(
     const int n_dims = ((int32_t *) src1->data)[1];
     const int mode   = ((int32_t *) src1->data)[2];
     const int n_ctx  = ((int32_t *) src1->data)[3];
+    float freq_base, freq_scale;
+    memcpy(&freq_base,  (int32_t *) src1->data + 4, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));

     assert(n_past >= 0);

@@ -12672,7 +12694,7 @@ static void ggml_compute_forward_rope_f16(
     // row index used to determine which thread to use
     int ir = 0;

-    const float theta_scale = powf(10000.0, -2.0f/n_dims);
+    const float theta_scale = powf(freq_base, -2.0f/n_dims);

     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
@@ -12684,7 +12706,7 @@ static void ggml_compute_forward_rope_f16(
             if (ir++ < ir0) continue;
             if (ir   > ir1) break;

-            float theta = (float)p;
+            float theta = freq_scale * (float)p;

             if (is_glm) {
                 theta = MIN(p, n_ctx - 2);
```
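One subtlety worth calling out: src1 is a GGML_TYPE_I32 tensor, so the two float parameters are stored bit-for-bit in its last two int32 slots via memcpy; a plain assignment would convert the float to an integer and corrupt it. A standalone C sketch of the round trip (the buffer name `slots` is hypothetical, not from ggml):

```c
#include <stdint.h>
#include <string.h>
#include <stdio.h>

// Minimal sketch of the float-in-int32 packing used above: memcpy
// copies the float's bit pattern into an int32 slot, so the value
// survives the round trip exactly. Writing `slots[4] = freq_base`
// instead would convert 80000.0f to the integer 80000 and lose the
// float representation.
int main(void) {
    int32_t slots[6] = {0};            // stands in for the I32 src1 tensor
    const float freq_base  = 80000.0f;
    const float freq_scale = 0.5f;

    memcpy(&slots[4], &freq_base,  sizeof(float)); // pack
    memcpy(&slots[5], &freq_scale, sizeof(float));

    float base_out, scale_out;
    memcpy(&base_out,  &slots[4], sizeof(float));  // unpack
    memcpy(&scale_out, &slots[5], sizeof(float));

    printf("base = %g, scale = %g\n", base_out, scale_out);
    return 0;
}
```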

ggml.h

Lines changed: 11 additions & 0 deletions
```diff
@@ -1058,6 +1058,17 @@ extern "C" {
             int                   mode,
             int                   n_ctx);

+    // custom RoPE, in-place, returns view(a)
+    GGML_API struct ggml_tensor * ggml_rope_custom_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            int                   n_past,
+            int                   n_dims,
+            int                   mode,
+            float                 freq_base,
+            float                 freq_scale,
+            int                   n_ctx);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
```
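For callers, here is a hedged usage sketch of the new entry point. Only ggml_rope_custom_inplace and the standard ggml context/tensor calls come from the headers; the sizes and parameter values are illustrative assumptions, and graph computation is omitted because those entry points vary across ggml versions of this period:

```c
#include "ggml.h"
#include <stdio.h>

int main(void) {
    // a small scratch context; the size is arbitrary for this sketch
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    // toy activations: 8 rotary dimensions, 4 token positions
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 8, 4);

    // equivalent to ggml_rope_inplace(...) when freq_base = 10000.0f and
    // freq_scale = 1.0f; mode 0 is the original (non-NeoX) RoPE, and
    // n_ctx only matters for the GLM branch of the kernels
    struct ggml_tensor * r = ggml_rope_custom_inplace(
            ctx, a, /*n_past=*/0, /*n_dims=*/8, /*mode=*/0,
            /*freq_base=*/80000.0f, /*freq_scale=*/0.5f, /*n_ctx=*/0);

    printf("rope node with %d elements built\n", (int) ggml_nelements(r));

    ggml_free(ctx);
    return 0;
}
```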
