
Commit 3e5faa8

cuda : fix rope + add tests (#7452)
* cuda : fix rope pos data (ggml-ci)
* ggml : drop mode & 1 == 1 support for ggml_rope (ggml-ci)
* ggml : support freq_factors for f16 rope (CPU) (ggml-ci)
* tests : add rope tests using frequency factors (ggml-ci)
1 parent 201cc11 commit 3e5faa8
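
For context: the freq_factors tensor that this commit wires through the rope paths (src2 of the op) carries one scaling value per rotated dimension pair, and the kernels divide the base angle by it before computing the rotation. A minimal sketch of that step, taken from the f16 hunk below (same logic as the existing f32 path; the surrounding loop and variable setup are omitted):

    // ic indexes the rotated dimension; freq_factors holds at least n_dims/2 entries
    float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;

    float cos_theta, sin_theta;
    rope_yarn(
        theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
        &cos_theta, &sin_theta
    );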

File tree: 4 files changed, +47 -20 lines

ggml-cuda/rope.cu

Lines changed: 2 additions & 2 deletions
@@ -283,9 +283,9 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
-    if (is_neox) {
-        pos = (const int32_t *) src1_d;
+    pos = (const int32_t *) src1_d;
 
+    if (is_neox) {
         if (src2 != nullptr) {
             freq_factors = (const float *) src2->data;
         }
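
The fix itself is an ordering change: pos was previously assigned only inside the is_neox branch, so non-NeoX rope on CUDA ran with an unset position pointer. Reconstructed from the hunk above, the relevant part of ggml_cuda_op_rope now reads:

    pos = (const int32_t *) src1_d;                      // positions are needed for every rope mode

    if (is_neox) {
        if (src2 != nullptr) {
            freq_factors = (const float *) src2->data;   // optional per-dimension frequency factors
        }
    }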

ggml.c

Lines changed: 18 additions & 2 deletions
@@ -6245,6 +6245,8 @@ static struct ggml_tensor * ggml_rope_impl(
         float                 xpos_base,
         bool                  xpos_down,
         bool                  inplace) {
+    GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported");
+
     GGML_ASSERT(ggml_is_vector(b));
     GGML_ASSERT(b->type == GGML_TYPE_I32);
     GGML_ASSERT(a->ne[2] == b->ne[0]);
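
Callers that still request the deprecated "skip n_past" variant (any mode with the low bit set) now trip this assert. Using the ggml_rope signature seen in the test changes below, the remaining valid calls look like this (sketch; a and pos are the usual input and position tensors):

    ggml_tensor * out_norm = ggml_rope(ctx, a, pos, n_dims, 0, n_ctx); // default rope
    ggml_tensor * out_neox = ggml_rope(ctx, a, pos, n_dims, 2, n_ctx); // GPT-NeoX style rope
    // ggml_rope(ctx, a, pos, n_dims, 1, n_ctx);                       // mode & 1: aborts via the new assert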
@@ -14413,7 +14415,7 @@ static void ggml_compute_forward_rope_f32(
             freq_factors = (const float *) src2->data;
         }
     } else {
-        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for mode 1");
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
     }
 
     // backward process uses inverse rotation by cos and sin.
@@ -14529,13 +14531,15 @@ static void ggml_compute_forward_rope_f32(
         }
     }
 
+// TODO: deduplicate f16/f32 code
 static void ggml_compute_forward_rope_f16(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
         const bool forward) {
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
+    const struct ggml_tensor * src2 = dst->src[2];
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
         return;
@@ -14588,6 +14592,17 @@ static void ggml_compute_forward_rope_f16(
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
+    const float * freq_factors = NULL;
+    if (is_neox) {
+        if (src2 != NULL) {
+            GGML_ASSERT(src2->type == GGML_TYPE_F32);
+            GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+            freq_factors = (const float *) src2->data;
+        }
+    } else {
+        GGML_ASSERT(src2 == NULL && "TODO: freq_factors not implemented for !is_neox");
+    }
+
     // backward process uses inverse rotation by cos and sin.
     // cos and sin build a rotation matrix, where the inverse is the transpose.
     // this essentially just switches the sign of sin.
@@ -14660,10 +14675,11 @@ static void ggml_compute_forward_rope_f16(
 
                         // simplified from `(ib * n_dims + ic) * inv_ndims`
                         float cur_rot = inv_ndims * ic - ib;
+                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
                         sin_theta *= sin_sign;

ggml.h

Lines changed: 1 addition & 1 deletion
@@ -1460,7 +1460,7 @@ extern "C" {
             struct ggml_tensor  * b);
 
     // rotary position embedding
-    // if mode & 1 == 1, skip n_past elements (DEPRECATED)
+    // if mode & 1 == 1, skip n_past elements (NOT SUPPORTED)
     // if mode & 2 == 1, GPT-NeoX style
     // if mode & 4 == 1, ChatGLM style
     //
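
Frequency factors are not passed through plain ggml_rope; they go through the extended entry point. The sketch below mirrors the call added to test_rope::build_graph further down; the literals after n_ctx are the test's defaults (presumably n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow):

    ggml_tensor * pos  = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);    // one position per row
    ggml_tensor * freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2); // one factor per dimension pair
    ggml_tensor * out  = ggml_rope_ext(ctx, a, pos, freq, n_dims, 2 /* NeoX */, n_ctx,
                                       0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
    // passing nullptr instead of freq keeps the previous behaviour (no frequency factors)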

tests/test-backend-ops.cpp

Lines changed: 26 additions & 15 deletions
@@ -1142,20 +1142,22 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    bool ff;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, n_dims, mode, n_ctx);
+        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
-        ggml_tensor * out = ggml_rope(ctx, a, pos, n_dims, mode, n_ctx);
+        ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
         return out;
     }
 
@@ -1169,7 +1171,12 @@ struct test_rope : public test_case {
             }
             ggml_backend_tensor_set(t, data.data(), 0, ne[2] * sizeof(int));
         } else {
-            init_tensor_uniform(t);
+            if (t->ne[0] == n_dims/2) {
+                // frequency factors in the range [0.9f, 1.1f]
+                init_tensor_uniform(t, 0.9f, 1.1f);
+            } else {
+                init_tensor_uniform(t);
+            }
         }
     }
 }
@@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512)); // llama 65B
-        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512)); // neox (falcon 7B)
-        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512)); // neox (falcon 40B)
-        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512)); // neox (stablelm)
-        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512)); // neox (phi-2)
+        // TODO: ff not supported yet for !neox
+        test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, false)); // llama 7B
+        test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, false)); // llama 13B
+        test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, false)); // llama 30B
+        test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, false)); // llama 65B
+
+        for (bool ff : {false, true}) { // freq_factors
+            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
+            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
+            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+        }
     }
 
     test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
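
The non-NeoX llama shapes keep ff = false because of the TODO above, while every NeoX shape now runs once without and once with a frequency-factor tensor. Covering an additional shape only takes another emplace_back; for example, a hypothetical 128-dim NeoX head with factors enabled (not part of this commit) would be registered as:

        test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 2, 512, true)); // hypothetical extra NeoX + ff case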
