Skip to content

Commit 45d6c58

Browse files
committed
test-backend-ops: Add F16 mask test cases
1 parent 53847e4 commit 45d6c58

File tree

1 file changed

+29
-16
lines changed

1 file changed

+29
-16
lines changed

tests/test-backend-ops.cpp

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2282,11 +2282,12 @@ struct test_soft_max : public test_case {
22822282
const ggml_type type;
22832283
const std::array<int64_t, 4> ne;
22842284
const bool mask;
2285+
const ggml_type m_prec;
22852286
const float scale;
22862287
const float max_bias;
22872288

22882289
std::string vars() override {
2289-
return VARS_TO_STR5(type, ne, mask, scale, max_bias);
2290+
return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
22902291
}
22912292

22922293
// the 1024 test with bias occasionally fails:
@@ -2298,9 +2299,10 @@ struct test_soft_max : public test_case {
22982299
test_soft_max(ggml_type type = GGML_TYPE_F32,
22992300
std::array<int64_t, 4> ne = {10, 5, 4, 3},
23002301
bool mask = false,
2302+
ggml_type m_prec = GGML_TYPE_F32,
23012303
float scale = 1.0f,
23022304
float max_bias = 0.0f)
2303-
: type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2305+
: type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
23042306

23052307
ggml_tensor * build_graph(ggml_context * ctx) override {
23062308
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2309,7 +2311,7 @@ struct test_soft_max : public test_case {
23092311

23102312
ggml_tensor * mask = nullptr;
23112313
if (this->mask) {
2312-
mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
2314+
mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
23132315
ggml_set_name(mask, "mask");
23142316
}
23152317

@@ -4071,17 +4073,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40714073
for (float scale : {1.0f, 0.1f}) {
40724074
for (int64_t ne0 : {16, 1024}) {
40734075
for (int64_t ne1 : {16, 1024}) {
4074-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
4075-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
4076+
if (mask) {
4077+
for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4078+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, scale, max_bias));
4079+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
4080+
}
4081+
} else {
4082+
/* The precision of mask here doesn't matter as boolean mask is false */
4083+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
4084+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
4085+
}
40764086
}
40774087
}
40784088
}
40794089
}
40804090
}
4081-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, 0.1f, 0.0f));
4082-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
4083-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 0.0f));
4084-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, 0.1f, 8.0f));
4091+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
4092+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
4093+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
4094+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 0.0f));
4095+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 0.0f));
4096+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, 0.1f, 8.0f));
4097+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, 0.1f, 8.0f));
40854098

40864099
for (float max_bias : {0.0f, 8.0f}) {
40874100
for (float scale : {1.0f, 0.1f}) {
@@ -4217,13 +4230,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
42174230
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
42184231
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
42194232

4220-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
4221-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
4222-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
4223-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
4224-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
4225-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
4226-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
4233+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4234+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4235+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4236+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4237+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4238+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
4239+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
42274240

42284241
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
42294242
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));

0 commit comments

Comments
 (0)