@@ -2282,11 +2282,12 @@ struct test_soft_max : public test_case {
2282
2282
const ggml_type type;
2283
2283
const std::array<int64_t , 4 > ne;
2284
2284
const bool mask;
2285
+ const ggml_type m_prec;
2285
2286
const float scale;
2286
2287
const float max_bias;
2287
2288
2288
2289
std::string vars () override {
2289
- return VARS_TO_STR5 (type, ne, mask, scale, max_bias);
2290
+ return VARS_TO_STR6 (type, ne, mask, m_prec , scale, max_bias);
2290
2291
}
2291
2292
2292
2293
// the 1024 test with bias occasionally fails:
@@ -2298,9 +2299,10 @@ struct test_soft_max : public test_case {
2298
2299
test_soft_max (ggml_type type = GGML_TYPE_F32,
2299
2300
std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
2300
2301
bool mask = false ,
2302
+ ggml_type m_prec = GGML_TYPE_F32,
2301
2303
float scale = 1 .0f ,
2302
2304
float max_bias = 0 .0f )
2303
- : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2305
+ : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
2304
2306
2305
2307
ggml_tensor * build_graph (ggml_context * ctx) override {
2306
2308
ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
@@ -2309,7 +2311,7 @@ struct test_soft_max : public test_case {
2309
2311
2310
2312
ggml_tensor * mask = nullptr ;
2311
2313
if (this ->mask ) {
2312
- mask = ggml_new_tensor_2d (ctx, GGML_TYPE_F32 , ne[0 ], ne[1 ]);
2314
+ mask = ggml_new_tensor_2d (ctx, m_prec , ne[0 ], ne[1 ]);
2313
2315
ggml_set_name (mask, " mask" );
2314
2316
}
2315
2317
@@ -4071,17 +4073,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
4071
4073
for (float scale : {1 .0f , 0 .1f }) {
4072
4074
for (int64_t ne0 : {16 , 1024 }) {
4073
4075
for (int64_t ne1 : {16 , 1024 }) {
4074
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, scale, max_bias));
4075
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, scale, max_bias));
4076
+ if (mask) {
4077
+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4078
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, scale, max_bias));
4079
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, scale, max_bias));
4080
+ }
4081
+ } else {
4082
+ /* The precision of mask here doesn't matter as boolean mask is false */
4083
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4084
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4085
+ }
4076
4086
}
4077
4087
}
4078
4088
}
4079
4089
}
4080
4090
}
4081
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4082
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , 0 .1f , 0 .0f ));
4083
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4084
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 8 .0f ));
4091
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4092
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4093
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4094
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4095
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4096
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 8 .0f ));
4097
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 8 .0f ));
4085
4098
4086
4099
for (float max_bias : {0 .0f , 8 .0f }) {
4087
4100
for (float scale : {1 .0f , 0 .1f }) {
@@ -4217,13 +4230,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
4217
4230
test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {8192 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
4218
4231
test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {3072 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
4219
4232
4220
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4221
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4222
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4223
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4224
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4225
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4226
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4233
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4234
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4235
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4236
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4237
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4238
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4239
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4227
4240
4228
4241
test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {32 , 10 , 1 , 1 }));
4229
4242
test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {1024 , 10 , 1 , 1 }));
0 commit comments