@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
37
37
return ss;
38
38
}
39
39
40
- bool validateConfig (const MatmulConfig &cfg) {
40
+ bool validateConfig (const MatmulConfig &cfg, ArrayRef< uint32_t > shape = {} ) {
41
41
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
42
42
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
43
43
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
47
47
cfg.NBlock % cfg.innerMostNBlock != 0 ||
48
48
cfg.KBlock % cfg.innerMostKBlock != 0 )
49
49
return false ;
50
+ if (!shape.empty ()) {
51
+ // KThreads is not shrunk automatically, so a config that asks for more K threads than there are K blocks is invalid
52
+ // shape is ordered {M, N, K}, so K is shape[2]; a non-empty shape must therefore have at least 3 entries
53
+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
54
+ return false ;
55
+ }
50
56
return true ;
51
57
}
52
58
@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179
185
ArrayRef<uint32_t > shape,
180
186
const MatmulConfig &config,
181
187
CPUTargetDescriptionAnalysis &sysDesc) {
182
- assert (validateConfig (config) && " config is invalid" );
188
+ assert (validateConfig (config, shape ) && " config is invalid" );
183
189
assert (shape.size () >= 3 && " shape.size() should >= 3" );
184
190
uint32_t M = shape[0 ], N = shape[1 ];
185
191
double cost = 0 ;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361
367
}
362
368
363
369
// read the config from the attributes for tuning
364
- bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370
+ bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371
+ ArrayRef<uint32_t > shape) {
365
372
size_t cfgItemCnt = 0 ;
366
373
for (const auto &attr : attrs) {
367
374
if (attr.getName () == " KBlock" ) {
@@ -393,7 +400,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393
400
cfgItemCnt++;
394
401
}
395
402
}
396
- if (validateConfig (config)) {
403
+ if (validateConfig (config, shape )) {
397
404
return cfgItemCnt == 9 ;
398
405
} else {
399
406
LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
@@ -483,7 +490,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483
490
484
491
// try to read the config from the attributes
485
492
SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
486
- bool hasPredefinedConfig = readConfigFromAttrs (config, attrs);
493
+ bool hasPredefinedConfig =
494
+ readConfigFromAttrs (config, attrs, SmallVector<uint32_t >{M, N, K});
487
495
488
496
// if there is a given config, skip the cost model
489
497
if (!hasPredefinedConfig) {
@@ -520,7 +528,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520
528
hasConfig = true ;
521
529
}
522
530
523
- assert (validateConfig (config) && " config is invalid" );
524
531
return config;
525
532
}
526
533
} // namespace gc
0 commit comments