@@ -1142,20 +1142,22 @@ struct test_rope : public test_case {
1142
1142
int n_dims;
1143
1143
int mode;
1144
1144
int n_ctx;
1145
+ bool ff;
1145
1146
1146
1147
// Returns the string built by the VARS_TO_STRn macro from this test case's
// parameters. This hunk switches from the 5-argument to the 6-argument macro
// so that the new `ff` member is reported alongside the existing ones.
// NOTE(review): the exact output format is defined by the macro, which is not
// visible in this hunk.
std::string vars () override {
1147
- return VARS_TO_STR5 (type, ne, n_dims, mode, n_ctx);
1148
+ return VARS_TO_STR6 (type, ne, n_dims, mode, n_ctx, ff );
1148
1149
}
1149
1150
1150
1151
// Constructor: copies every argument into the member of the same name.
// The change appends `ff` as a trailing parameter defaulting to false, so
// existing call sites that pass only the first five arguments still compile
// and keep `ff` disabled. build_graph reads `ff` to decide whether to create
// a frequency tensor.
test_rope (ggml_type type = GGML_TYPE_F32,
1151
1152
std::array<int64_t , 4 > ne = {10 , 10 , 10 , 1 },
1152
- int n_dims = 10 , int mode = 0 , int n_ctx = 512 )
1153
- : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx) {}
1153
+ int n_dims = 10 , int mode = 0 , int n_ctx = 512 , bool ff = false )
1154
+ : type(type), ne(ne), n_dims(n_dims), mode(mode), n_ctx(n_ctx), ff(ff) {}
1154
1155
1155
1156
// Builds the graph for this test case and returns its output tensor.
//   a    - 4-D tensor of element type `type` with dimensions `ne`
//   pos  - 1-D GGML_TYPE_I32 tensor with ne[2] elements
//   freq - 1-D GGML_TYPE_F32 tensor with n_dims/2 elements when `ff` is true,
//          otherwise nullptr
// The change replaces the ggml_rope call with ggml_rope_ext so that `freq`
// can be passed through.
// NOTE(review): the trailing literal arguments (0, 10000.0f, 1.0f, ...) are
// passed positionally; their meaning is defined by ggml_rope_ext's
// declaration, which is not visible here - confirm against ggml.h.
ggml_tensor * build_graph (ggml_context * ctx) override {
1156
1157
ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
1157
1158
ggml_tensor * pos = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, ne[2 ]);
1158
- ggml_tensor * out = ggml_rope (ctx, a, pos, n_dims, mode, n_ctx);
1159
+ ggml_tensor * freq = ff ? ggml_new_tensor_1d (ctx, GGML_TYPE_F32, n_dims/2 ) : nullptr ;
1160
+ ggml_tensor * out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, n_ctx, 0 , 10000 .0f , 1 .0f , 0 .0f , 1 .0f , 0 .0f , 0 .0f );
1159
1161
return out;
1160
1162
}
1161
1163
@@ -1169,7 +1171,12 @@ struct test_rope : public test_case {
1169
1171
}
1170
1172
ggml_backend_tensor_set (t, data.data (), 0 , ne[2 ] * sizeof (int ));
1171
1173
} else {
1172
- init_tensor_uniform (t);
1174
+ if (t->ne [0 ] == n_dims/2 ) {
1175
+ // frequency factors in the range [0.9f, 1.1f]
1176
+ init_tensor_uniform (t, 0 .9f , 1 .1f );
1177
+ } else {
1178
+ init_tensor_uniform (t);
1179
+ }
1173
1180
}
1174
1181
}
1175
1182
}
@@ -2188,16 +2195,20 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
2188
2195
test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 8 .0f ));
2189
2196
2190
2197
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
2191
- test_cases.emplace_back (new test_rope (type, {128 , 32 , 10 , 1 }, 128 , 0 , 512 )); // llama 7B
2192
- test_cases.emplace_back (new test_rope (type, {128 , 40 , 10 , 1 }, 128 , 0 , 512 )); // llama 13B
2193
- test_cases.emplace_back (new test_rope (type, {128 , 52 , 10 , 1 }, 128 , 0 , 512 )); // llama 30B
2194
- test_cases.emplace_back (new test_rope (type, {128 , 64 , 10 , 1 }, 128 , 0 , 512 )); // llama 65B
2195
- test_cases.emplace_back (new test_rope (type, { 64 , 1 , 10 , 1 }, 64 , 2 , 512 )); // neox (falcon 7B)
2196
- test_cases.emplace_back (new test_rope (type, { 64 , 71 , 10 , 1 }, 64 , 2 , 512 )); // neox (falcon 7B)
2197
- test_cases.emplace_back (new test_rope (type, { 64 , 8 , 10 , 1 }, 64 , 2 , 512 )); // neox (falcon 40B)
2198
- test_cases.emplace_back (new test_rope (type, { 64 , 128 , 10 , 1 }, 64 , 2 , 512 )); // neox (falcon 40B)
2199
- test_cases.emplace_back (new test_rope (type, { 80 , 32 , 10 , 1 }, 20 , 2 , 512 )); // neox (stablelm)
2200
- test_cases.emplace_back (new test_rope (type, { 80 , 32 , 10 , 1 }, 32 , 2 , 512 )); // neox (phi-2)
2198
+ // TODO: ff not supported yet for !neox
2199
+ test_cases.emplace_back (new test_rope (type, {128 , 32 , 10 , 1 }, 128 , 0 , 512 , false )); // llama 7B
2200
+ test_cases.emplace_back (new test_rope (type, {128 , 40 , 10 , 1 }, 128 , 0 , 512 , false )); // llama 13B
2201
+ test_cases.emplace_back (new test_rope (type, {128 , 52 , 10 , 1 }, 128 , 0 , 512 , false )); // llama 30B
2202
+ test_cases.emplace_back (new test_rope (type, {128 , 64 , 10 , 1 }, 128 , 0 , 512 , false )); // llama 65B
2203
+
2204
+ for (bool ff : {false , true }) { // freq_factors
2205
+ test_cases.emplace_back (new test_rope (type, { 64 , 1 , 10 , 1 }, 64 , 2 , 512 , ff)); // neox (falcon 7B)
2206
+ test_cases.emplace_back (new test_rope (type, { 64 , 71 , 10 , 1 }, 64 , 2 , 512 , ff)); // neox (falcon 7B)
2207
+ test_cases.emplace_back (new test_rope (type, { 64 , 8 , 10 , 1 }, 64 , 2 , 512 , ff)); // neox (falcon 40B)
2208
+ test_cases.emplace_back (new test_rope (type, { 64 , 128 , 10 , 1 }, 64 , 2 , 512 , ff)); // neox (falcon 40B)
2209
+ test_cases.emplace_back (new test_rope (type, { 80 , 32 , 10 , 1 }, 20 , 2 , 512 , ff)); // neox (stablelm)
2210
+ test_cases.emplace_back (new test_rope (type, { 80 , 32 , 10 , 1 }, 32 , 2 , 512 , ff)); // neox (phi-2)
2211
+ }
2201
2212
}
2202
2213
2203
2214
test_cases.emplace_back (new test_concat (GGML_TYPE_F32));
0 commit comments