@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
37
37
return ss;
38
38
}
39
39
40
- bool validateConfig (const MatmulConfig &cfg) {
40
+ bool validateConfig (const MatmulConfig &cfg, ArrayRef< uint32_t > shape ) {
41
41
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
42
42
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
43
43
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
47
47
cfg.NBlock % cfg.innerMostNBlock != 0 ||
48
48
cfg.KBlock % cfg.innerMostKBlock != 0 )
49
49
return false ;
50
+ if (!shape.empty ()) {
51
+ // KThreads will not shrink automatically
52
+ // K is shape[2]
53
+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
54
+ return false ;
55
+ }
50
56
return true ;
51
57
}
52
58
@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179
185
ArrayRef<uint32_t > shape,
180
186
const MatmulConfig &config,
181
187
CPUTargetDescriptionAnalysis &sysDesc) {
182
- assert (validateConfig (config) && " config is invalid" );
188
+ assert (validateConfig (config, shape ) && " config is invalid" );
183
189
assert (shape.size () >= 3 && " shape.size() should >= 3" );
184
190
uint32_t M = shape[0 ], N = shape[1 ];
185
191
double cost = 0 ;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361
367
}
362
368
363
369
// read the config from the attributes for tuning
364
- bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370
+ bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371
+ ArrayRef<uint32_t > shape) {
365
372
size_t cfgItemCnt = 0 ;
366
373
for (const auto &attr : attrs) {
367
374
if (attr.getName () == " KBlock" ) {
@@ -393,10 +400,15 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393
400
cfgItemCnt++;
394
401
}
395
402
}
396
- if (validateConfig (config)) {
397
- return cfgItemCnt == 9 ;
398
- } else {
399
- LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
403
+ if (cfgItemCnt != 9 ) {
404
+ LLVM_DEBUG (llvm::dbgs () << " The predefined matmul config is incomplete. "
405
+ " Default matmul config will be set.\n " );
406
+ return false ;
407
+ }
408
+ if (validateConfig (config, shape))
409
+ return true ;
410
+ else {
411
+ assert (0 && " config is invalid" );
400
412
return false ;
401
413
}
402
414
}
@@ -483,7 +495,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483
495
484
496
// try to read the config from the attributes
485
497
SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
486
- bool hasPredefinedConfig = readConfigFromAttrs (config, attrs);
498
+ bool hasPredefinedConfig =
499
+ readConfigFromAttrs (config, attrs, SmallVector<uint32_t >{M, N, K});
487
500
488
501
// if there is a given config, skip the cost model
489
502
if (!hasPredefinedConfig) {
@@ -511,6 +524,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
511
524
}
512
525
if (!configCandidates.empty ())
513
526
config = configCandidates[0 ];
527
+
528
+ assert (validateConfig (config, shape) && " config is invalid" );
514
529
}
515
530
516
531
LLVM_DEBUG (llvm::dbgs ()
@@ -520,7 +535,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520
535
hasConfig = true ;
521
536
}
522
537
523
- assert (validateConfig (config) && " config is invalid" );
524
538
return config;
525
539
}
526
540
} // namespace gc
0 commit comments