Skip to content

Commit 681144e

Browse files
committed
tests : add ggml_set_rows
1 parent 8cd7b3a commit 681144e

File tree

1 file changed

+81
-0
lines changed

1 file changed

+81
-0
lines changed

tests/test-backend-ops.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,78 @@ struct test_get_rows_back : public test_case {
12131213
}
12141214
};
12151215

1216+
// GGML_OP_SET_ROWS
1217+
struct test_set_rows : public test_case {
1218+
const ggml_type type;
1219+
const int n; // cols
1220+
const int m; // rows
1221+
const int r; // rows to set
1222+
const int b0; // batch size
1223+
const int b1; // batch size
1224+
const int bs; // batch size src (for testing broadcast)
1225+
const bool v; // view (non-contiguous src1)
1226+
1227+
std::string vars() override {
1228+
return VARS_TO_STR7(type, n, m, r, b0, bs, v);
1229+
}
1230+
1231+
test_set_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, int bs = 1, bool v = false)
1232+
: type(type), n(n), m(m), r(r), b0(b), b1(3), bs(bs), v(v) {
1233+
GGML_ASSERT(b0 % bs == 0 && "b0 must be a multiple of bs");
1234+
GGML_ASSERT(r <= m && "r must be less than or equal to m");
1235+
}
1236+
1237+
ggml_tensor * build_graph(ggml_context * ctx) override {
1238+
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, n, m, b0, b1);
1239+
ggml_set_name(dst, "dst");
1240+
1241+
ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n, r, b0, b1);
1242+
ggml_set_name(src, "src");
1243+
1244+
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, bs, b1);
1245+
ggml_set_name(row_idxs, "row_idxs");
1246+
1247+
if (v) {
1248+
src = ggml_view_4d(ctx, src, n, r/2, b0, b1, src->nb[1], src->nb[2], src->nb[3], 0);
1249+
row_idxs = ggml_view_3d(ctx, row_idxs, r/2, bs, b1, row_idxs->nb[1], row_idxs->nb[2], 0);
1250+
ggml_set_name(row_idxs, "view_of_rows");
1251+
}
1252+
1253+
ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
1254+
ggml_set_name(out, "out");
1255+
1256+
return out;
1257+
}
1258+
1259+
void initialize_tensors(ggml_context * ctx) override {
1260+
std::random_device rd;
1261+
std::default_random_engine rng(rd());
1262+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1263+
if (t->type == GGML_TYPE_I64) {
1264+
if (ggml_is_view_op(t->op)) {
1265+
continue;
1266+
}
1267+
1268+
for (int i2 = 0; i2 < t->ne[2]; i2++) {
1269+
for (int i1 = 0; i1 < t->ne[1]; i1++) {
1270+
std::vector<int64_t> data(m);
1271+
for (int i = 0; i < m; i++) {
1272+
data[i] = i;
1273+
}
1274+
std::shuffle(data.begin(), data.end(), rng);
1275+
data.resize(t->ne[0]);
1276+
1277+
const size_t offs = i1*t->nb[1] + i2*t->nb[2];
1278+
ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
1279+
}
1280+
}
1281+
} else {
1282+
init_tensor_uniform(t);
1283+
}
1284+
}
1285+
}
1286+
};
1287+
12161288
// GGML_OP_ARGMAX
12171289
struct test_argmax : public test_case {
12181290
const ggml_type type;
@@ -3984,6 +4056,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39844056
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
39854057
}
39864058

4059+
test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, 1, 8, 2, 1, 1, false));
4060+
for (ggml_type type : all_types) {
4061+
for (int b : {1, 7}) {
4062+
for (bool v : {false, true}) {
4063+
test_cases.emplace_back(new test_set_rows(type, 256, 5, 4, b, 1, v));
4064+
}
4065+
}
4066+
}
4067+
39874068
for (ggml_type type_input : {GGML_TYPE_F32}) {
39884069
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
39894070
for (int k0 : {1, 3}) {

0 commit comments

Comments
 (0)