Skip to content

Commit ac3194d

Browse files
CISCqnixsynapse
authored andcommitted
add tests for ggml_glu_split
1 parent 94361fd commit ac3194d

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

tests/test-backend-ops.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,6 +1151,60 @@ struct test_glu : public test_case {
11511151
}
11521152
};
11531153

1154+
struct test_glu_split : public test_case {
1155+
const ggml_glu_op op;
1156+
const ggml_type type;
1157+
const std::array<int64_t, 4> ne_a;
1158+
int v; // view (1 : non-contiguous a)
1159+
1160+
std::string vars() override {
1161+
return VARS_TO_STR3(type, ne_a, v);
1162+
}
1163+
1164+
test_glu_split(ggml_glu_op op,
1165+
ggml_type type = GGML_TYPE_F32,
1166+
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
1167+
int v = 0)
1168+
: op(op), type(type), ne_a(ne_a), v(v) {}
1169+
1170+
ggml_tensor * build_graph(ggml_context * ctx) override {
1171+
ggml_tensor * a;
1172+
ggml_tensor * b;
1173+
if (v & 1) {
1174+
auto ne = ne_a; ne[0] *= 3;
1175+
a = ggml_new_tensor(ctx, type, 4, ne.data());
1176+
ggml_set_name(a, "a");
1177+
1178+
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
1179+
ggml_set_name(a, "view_of_a");
1180+
1181+
b = ggml_new_tensor(ctx, type, 4, ne.data());
1182+
ggml_set_name(b, "b");
1183+
1184+
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
1185+
ggml_set_name(a, "view_of_b");
1186+
} else {
1187+
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
1188+
ggml_set_name(a, "a");
1189+
1190+
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
1191+
ggml_set_name(b, "b");
1192+
}
1193+
1194+
ggml_tensor * out = ggml_glu_split(ctx, a, b, op);
1195+
ggml_set_name(out, "out");
1196+
1197+
return out;
1198+
}
1199+
1200+
void initialize_tensors(ggml_context * ctx) override {
1201+
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
1202+
// test extended range of values to check for NaNs in GELU
1203+
init_tensor_uniform(t, -150.f, 150.f);
1204+
}
1205+
}
1206+
};
1207+
11541208
// GGML_OP_GET_ROWS
11551209
struct test_get_rows : public test_case {
11561210
const ggml_type type;
@@ -3986,6 +4040,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39864040
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
39874041
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
39884042
}
4043+
4044+
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
4045+
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
39894046
}
39904047
}
39914048
}

0 commit comments

Comments
 (0)