@@ -1142,22 +1142,25 @@ struct test_rope : public test_case {
     int n_dims;
     int mode;
     int n_ctx;
+    float fs; // freq_scale
+    float ef; // ext_factor
+    float af; // attn_factor
     bool ff;
 
     std::string vars() override {
-        return VARS_TO_STR6(type, ne, n_dims, mode, n_ctx, ff);
+        return VARS_TO_STR9(type, ne, n_dims, mode, n_ctx, fs, ef, af, ff);
     }
 
     test_rope(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 1},
-            int n_dims = 10, int mode = 0, int n_ctx = 512, bool ff = false)
-        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
+            int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, float ef = 0.0f, float af = 0.0f, bool ff = false)
+        : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne[2]);
         ggml_tensor * freq = ff ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2) : nullptr;
-        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
+        ggml_tensor * out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, n_ctx, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
         return out;
     }
 
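For reference, the updated call forwards the three new fields into the positional arguments that were previously hard-coded. A minimal annotated sketch of the same call — the argument names follow my reading of the ggml header around this change and should be treated as assumptions, not part of this PR:

    // same call as in the hunk above, with each positional argument labeled
    ggml_tensor * out = ggml_rope_ext(
        ctx, a, pos, freq,          // input, positions, optional freq_factors
        n_dims, mode, n_ctx,
        /*n_orig_ctx  =*/ 0,
        /*freq_base   =*/ 10000.0f,
        /*freq_scale  =*/ fs,       // was hard-coded to 1.0f
        /*ext_factor  =*/ ef,       // was hard-coded to 0.0f
        /*attn_factor =*/ af,       // was hard-coded to 1.0f
        /*beta_fast   =*/ 1.0f,
        /*beta_slow   =*/ 1.0f);

Note that the two trailing arguments (beta_fast/beta_slow in this reading) also change in this hunk, from 0.0f/0.0f to 1.0f/1.0f.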
@@ -2198,20 +2201,26 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
 
-    for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-        // TODO: ff not supported yet for !neox
-        test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, false)); // llama 7B
-        test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, false)); // llama 13B
-        test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, false)); // llama 30B
-        test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, false)); // llama 65B
-
-        for (bool ff : {false, true}) { // freq_factors
-            test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, ff)); // neox (falcon 7B)
-            test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, ff)); // neox (falcon 40B)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, ff)); // neox (stablelm)
-            test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, ff)); // neox (phi-2)
+    for (float fs : { 1.0f, 1.4245f }) {
+        for (float ef : { 0.0f, 0.7465f }) {
+            for (float af : { 1.0f, 1.4245f }) {
+                for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                    // TODO: ff not supported yet for !neox
+                    test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 7B
+                    test_cases.emplace_back(new test_rope(type, {128,  40, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 13B
+                    test_cases.emplace_back(new test_rope(type, {128,  52, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 30B
+                    test_cases.emplace_back(new test_rope(type, {128,  64, 10, 1}, 128, 0, 512, fs, ef, af, false)); // llama 65B
+
+                    for (bool ff : {false, true}) { // freq_factors
+                        test_cases.emplace_back(new test_rope(type, { 64,   1, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 7B)
+                        test_cases.emplace_back(new test_rope(type, { 64,  71, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 7B)
+                        test_cases.emplace_back(new test_rope(type, { 64,   8, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1},  64, 2, 512, fs, ef, af, ff)); // neox (falcon 40B)
+                        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  20, 2, 512, fs, ef, af, ff)); // neox (stablelm)
+                        test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512, fs, ef, af, ff)); // neox (phi-2)
+                    }
+                }
+            }
         }
     }
 
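Counting it out: the three new loops give 2 freq_scale × 2 ext_factor × 2 attn_factor = 8 parameter combinations, and each combination runs every shape for both types, i.e. 8 × 2 × (4 llama + 6 neox × 2 ff) = 256 rope test cases, up from the previous 2 × 16 = 32.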