Skip to content

Commit 7d90d57

Browse files
committed
test-backend-ops: Add F16 mask test cases
1 parent 0ed0820 commit 7d90d57

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

tests/test-backend-ops.cpp

Lines changed: 29 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -2282,11 +2282,12 @@ struct test_soft_max : public test_case {
22822282
const ggml_type type;
22832283
const std::array<int64_t, 4> ne;
22842284
const bool mask;
2285+
const ggml_type m_prec;
22852286
const float scale;
22862287
const float max_bias;
22872288

22882289
std::string vars() override {
2289-
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
2290+
return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
22902291
}
22912292

22922293
// the 1024 test with bias occasionally fails:
@@ -2298,9 +2299,10 @@ struct test_soft_max : public test_case {
22982299
test_soft_max(ggml_type type = GGML_TYPE_F32,
22992300
std::array<int64_t, 4> ne = {10, 5, 4, 3},
23002301
bool mask = false,
2302+
ggml_type m_prec = GGML_TYPE_F32,
23012303
float scale = 1.0f,
23022304
float max_bias = 0.0f)
2303-
: type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2305+
: type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
23042306

23052307
ggml_tensor * build_graph(ggml_context * ctx) override {
23062308
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2309,7 +2311,7 @@ struct test_soft_max : public test_case {
23092311

23102312
ggml_tensor * mask = nullptr;
23112313
if (this->mask) {
2312-
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
2314+
mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
23132315
ggml_set_name(mask, "mask");
23142316
}
23152317

@@ -4078,17 +4080,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40784080
for (float scale : {1.0f, 0.1f}) {
40794081
for (int64_t ne0 : {16, 1024}) {
40804082
for (int64_t ne1 : {16, 1024}) {
4081-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
4082-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
4083+
if (mask) {
4084+
for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4085+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
4086+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
4087+
}
4088+
} else {
4089+
/* The precision of mask here doesn't matter as boolean mask is false */
4090+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
4091+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
4092+
}
40834093
}
40844094
}
40854095
}
40864096
}
40874097
}
4088-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
4089-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
4090-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
4091-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
4098+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
4099+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
4100+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
4101+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
4102+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
4103+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
4104+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
40924105

40934106
for (float max_bias : {0.0f, 8.0f}) {
40944107
for (float scale : {1.0f, 0.1f}) {
@@ -4224,13 +4237,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
42244237
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
42254238
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
42264239

4227-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
4228-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
4229-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
4230-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
4231-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
4232-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
4233-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
4240+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4241+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4242+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4243+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4244+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4245+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4246+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
42344247

42354248
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
42364249
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));

0 commit comments

Comments
 (0)