Skip to content

Commit 98e68b4

Browse files
committed
tests : add non-cont unary tests
1 parent fd5ea0f commit 98e68b4

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

tests/test-backend-ops.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -642,20 +642,29 @@ struct test_case {
642642
struct test_unary : public test_case {
643643
const ggml_unary_op op;
644644
const ggml_type type;
645-
const std::array<int64_t, 4> ne;
645+
const std::array<int64_t, 4> ne_a;
646+
int v; // view (1 : non-contiguous a)
646647

647648
std::string vars() override {
648-
return VARS_TO_STR2(type, ne);
649+
return VARS_TO_STR3(type, ne_a, v);
649650
}
650651

651652
test_unary(ggml_unary_op op,
652653
ggml_type type = GGML_TYPE_F32,
653-
std::array<int64_t, 4> ne = {128, 10, 10, 10})
654-
: op(op), type(type), ne(ne) {}
654+
std::array<int64_t, 4> ne_a = {128, 10, 10, 10},
655+
int v = 0)
656+
: op(op), type(type), ne_a(ne_a), v(v) {}
655657

656658
ggml_tensor * build_graph(ggml_context * ctx) override {
657-
ggml_tensor * in = ggml_new_tensor(ctx, type, 4, ne.data());
658-
ggml_tensor * out = ggml_unary(ctx, in, op);
659+
ggml_tensor * a;
660+
if (v & 1) {
661+
auto ne = ne_a; ne[0] *= 3;
662+
a = ggml_new_tensor(ctx, type, 4, ne.data());
663+
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
664+
} else {
665+
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
666+
}
667+
ggml_tensor * out = ggml_unary(ctx, a, op);
659668
return out;
660669
}
661670

@@ -2016,9 +2025,11 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
20162025
};
20172026

20182027
// unary ops
2019-
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
2020-
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
2021-
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }));
2028+
for (int v : {0, 1}) {
2029+
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
2030+
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 128, 10, 10, 10 }, v));
2031+
test_cases.emplace_back(new test_unary((ggml_unary_op) op, GGML_TYPE_F32, { 7, 13, 19, 23 }, v));
2032+
}
20222033
}
20232034

20242035
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));

0 commit comments

Comments
 (0)