Skip to content

Commit 9d5605f

Browse files
committed
tests : add rope tests
ggml-ci
1 parent cce3dcf commit 9d5605f

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

tests/test-backend-ops.cpp

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,22 +1142,25 @@ struct test_rope : public test_case {
11421142
int n_dims;
11431143
int mode;
11441144
int n_ctx;
1145+
float fs; // freq_scale
1146+
float ef; // ext_factor
1147+
float af; // attn_factor
11451148
bool ff;
11461149

11471150
std::string vars() override {
1148-
return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
1151+
return VARS_TO_STR9(type, ne, n_dims, mode, n_ctx, fs, ef, af, ff);
11491152
}
11501153

11511154
test_rope(ggml_type type = GGML_TYPE_F32,
11521155
std::array<int64_t, 4> ne = {10, 10, 10, 1},
1153-
int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
1154-
: type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
1156+
int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false)
1157+
: type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff) {}
11551158

11561159
ggml_tensor * build_graph(ggml_context * ctx) override {
11571160
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
11581161
ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
11591162
ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
1160-
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
1163+
ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
11611164
return out;
11621165
}
11631166

@@ -2213,20 +2216,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
22132216
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
22142217
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
22152218

2216-
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
2217-
// TODO: ff not supported yet for !neox
2218-
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, false)); // llama 7B
2219-
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, false)); // llama 13B
2220-
test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, false)); // llama 30B
2221-
test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, false)); // llama 65B
2222-
2223-
for (bool ff : {false, true}) { // freq_factors
2224-
test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
2225-
test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, ff)); // neox (falcon 7B)
2226-
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
2227-
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, ff)); // neox (falcon 40B)
2228-
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, ff)); // neox (stablelm)
2229-
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, ff)); // neox (phi-2)
2219+
for (float fs : { 1.0f, 1.4245f }) {
2220+
for (float ef : { 0.0f, 0.7465f }) {
2221+
for (float af : { 1.0f, 1.4245f }) {
2222+
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
2223+
// TODO: ff not supported yet for !neox
2224+
test_cases.emplace_back(new test_rope(type, {128, 32, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 7B
2225+
test_cases.emplace_back(new test_rope(type, {128, 40, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 13B
2226+
test_cases.emplace_back(new test_rope(type, {128, 52, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 30B
2227+
test_cases.emplace_back(new test_rope(type, {128, 64, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 65B
2228+
2229+
for (bool ff : {false, true}) { // freq_factors
2230+
test_cases.emplace_back(new test_rope(type, { 64, 1, 10, 1}, 64, 2, 512, fs, ef, af, ff)); // neox (falcon 7B)
2231+
test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512, fs, ef, af, ff)); // neox (falcon 7B)
2232+
test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512, fs, ef, af, ff)); // neox (falcon 40B)
2233+
test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512, fs, ef, af, ff)); // neox (falcon 40B)
2234+
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 20, 2, 512, fs, ef, af, ff)); // neox (stablelm)
2235+
test_cases.emplace_back(new test_rope(type, { 80, 32, 10, 1}, 32, 2, 512, fs, ef, af, ff)); // neox (phi-2)
2236+
}
2237+
}
2238+
}
22302239
}
22312240
}
22322241

0 commit comments

Comments
 (0)