@@ -2282,11 +2282,12 @@ struct test_soft_max : public test_case {
2282
2282
// Parameters of one soft-max test case; all are const and set by the constructor.
// Element type of the input tensor created in build_graph.
const ggml_type type;
2283
2283
// Dimensions of the 4-D input tensor; ne[0] and ne[1] also size the mask.
const std::array<int64_t , 4 > ne;
2284
2284
// When true, build_graph creates a 2-D mask tensor; otherwise the mask stays nullptr.
const bool mask;
2285
// Added by this change: type of the mask tensor (previously hard-coded to GGML_TYPE_F32).
+ const ggml_type m_prec;
2285
2286
// NOTE(review): usage of scale is not visible in this excerpt - presumably forwarded to the soft-max op; confirm.
const float scale;
2286
2287
// NOTE(review): usage of max_bias is not visible in this excerpt - presumably forwarded to the soft-max op; confirm.
const float max_bias;
2287
2288
2288
2289
// Describes this test case as a string assembled from its parameters.
// This change switches from VARS_TO_STR5 to VARS_TO_STR6 so that the new
// m_prec field is reported alongside type, ne, mask, scale and max_bias.
std::string vars () override {
2289
- return VARS_TO_STR5 (type, ne, mask, scale, max_bias);
2290
+ return VARS_TO_STR6 (type, ne, mask, m_prec , scale, max_bias);
2290
2291
}
2291
2292
2292
2293
// the 1024 test with bias occasionally fails:
@@ -2298,9 +2299,10 @@ struct test_soft_max : public test_case {
2298
2299
// Constructor: every argument has a default and is copied into the const
// member of the same name by the initializer list.
// The new m_prec argument is inserted between mask and scale, so callers that
// pass scale/max_bias positionally must now pass m_prec as well (the call
// sites elsewhere in this diff are updated accordingly).
test_soft_max (ggml_type type = GGML_TYPE_F32,
2299
2300
std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
2300
2301
bool mask = false ,
2302
+ ggml_type m_prec = GGML_TYPE_F32,
2301
2303
float scale = 1 .0f ,
2302
2304
float max_bias = 0 .0f )
2303
- : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2305
+ : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
2304
2306
2305
2307
ggml_tensor * build_graph (ggml_context * ctx) override {
2306
2308
ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
@@ -2309,7 +2311,7 @@ struct test_soft_max : public test_case {
2309
2311
2310
2312
ggml_tensor * mask = nullptr ;
2311
2313
if (this ->mask ) {
2312
- mask = ggml_new_tensor_2d (ctx, GGML_TYPE_F32 , ne[0 ], ne[1 ]);
2314
+ mask = ggml_new_tensor_2d (ctx, m_prec , ne[0 ], ne[1 ]);
2313
2315
ggml_set_name (mask, " mask" );
2314
2316
}
2315
2317
@@ -4078,17 +4080,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
4078
4080
for (float scale : {1 .0f , 0 .1f }) {
4079
4081
for (int64_t ne0 : {16 , 1024 }) {
4080
4082
for (int64_t ne1 : {16 , 1024 }) {
4081
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, scale, max_bias));
4082
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, scale, max_bias));
4083
+ if (mask) {
4084
+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4085
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, scale, max_bias));
4086
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, scale, max_bias));
4087
+ }
4088
+ } else {
4089
+ /* The mask precision is irrelevant here: when mask is false, no mask tensor is created */
4090
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4091
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4092
+ }
4083
4093
}
4084
4094
}
4085
4095
}
4086
4096
}
4087
4097
}
4088
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4089
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , 0 .1f , 0 .0f ));
4090
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4091
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 8 .0f ));
4098
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4099
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4100
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4101
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4102
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4103
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 8 .0f ));
4104
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 8 .0f ));
4092
4105
4093
4106
for (float max_bias : {0 .0f , 8 .0f }) {
4094
4107
for (float scale : {1 .0f , 0 .1f }) {
@@ -4224,13 +4237,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
4224
4237
test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {8192 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
4225
4238
test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {3072 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
4226
4239
4227
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4228
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4229
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4230
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4231
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4232
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4233
- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4240
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4241
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4242
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4243
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4244
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4245
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4246
+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4234
4247
4235
4248
test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {32 , 10 , 1 , 1 }));
4236
4249
test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {1024 , 10 , 1 , 1 }));
0 commit comments