@@ -37,22 +37,29 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
37
37
return ss;
38
38
}
39
39
40
- bool validateConfig (const MatmulConfig &cfg, ArrayRef<uint32_t > shape) {
40
+ bool validateConfig (const MatmulConfig &cfg, ArrayRef<uint32_t > shape,
41
+ bool allowIndivisibleInnerblock, bool isVNNIMM2D) {
41
42
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
42
43
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
43
44
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
44
45
cfg.innerMostKBlock <= 0 )
45
46
return false ;
46
47
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
47
- cfg.NBlock % cfg.innerMostNBlock != 0 ||
48
- cfg.KBlock % cfg.innerMostKBlock != 0 )
48
+ (shape[0 ] % cfg.innerMostMBlock != 0 && !allowIndivisibleInnerblock))
49
+ return false ;
50
+ if (cfg.NBlock % cfg.innerMostNBlock != 0 ||
51
+ ((shape[1 ] % cfg.innerMostNBlock != 0 ) && !allowIndivisibleInnerblock) ||
52
+ (shape[1 ] % cfg.NThreads != 0 && isVNNIMM2D &&
53
+ cfg.NBlock != cfg.innerMostNBlock ))
54
+ return false ;
55
+ if (cfg.KBlock % cfg.innerMostKBlock != 0 ||
56
+ ((shape[2 ] / cfg.KThreads % cfg.KBlock != 0 ||
57
+ shape[2 ] / cfg.KThreads % cfg.innerMostKBlock != 0 ) &&
58
+ !allowIndivisibleInnerblock))
59
+ return false ;
60
+ // KThreads will not shrink automatically
61
+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
49
62
return false ;
50
- if (!shape.empty ()) {
51
- // KThreads will not shrink automatically
52
- // K is shape[2]
53
- if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
54
- return false ;
55
- }
56
63
return true ;
57
64
}
58
65
@@ -185,7 +192,6 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
185
192
ArrayRef<uint32_t > shape,
186
193
const MatmulConfig &config,
187
194
CPUTargetDescriptionAnalysis &sysDesc) {
188
- assert (validateConfig (config, shape) && " config is invalid" );
189
195
assert (shape.size () >= 3 && " shape.size() should >= 3" );
190
196
uint32_t M = shape[0 ], N = shape[1 ];
191
197
double cost = 0 ;
@@ -367,8 +373,7 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
367
373
}
368
374
369
375
// read the config from the attributes for tuning
370
- bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371
- ArrayRef<uint32_t > shape) {
376
+ bool readConfigFromAttrs (MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
372
377
size_t cfgItemCnt = 0 ;
373
378
for (const auto &attr : attrs) {
374
379
if (attr.getName () == " KBlock" ) {
@@ -400,17 +405,28 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
400
405
cfgItemCnt++;
401
406
}
402
407
}
403
- if (cfgItemCnt != 9 ) {
404
- LLVM_DEBUG (llvm::dbgs () << " The predefined matmul config is incomplete. "
405
- " Default matmul config will be set.\n " );
408
+ return cfgItemCnt == 9 ;
409
+ }
410
+
411
+ bool readAndValidateConfig (MatmulConfig &config,
412
+ const linalg::LinalgOp &linalgOp,
413
+ ArrayRef<uint32_t > shape,
414
+ bool allowIndivisibleInnerBlock) {
415
+ SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
416
+ bool fullConfig = readConfigFromAttrs (config, attrs);
417
+ if (!fullConfig) {
418
+ LLVM_DEBUG (llvm::dbgs () << " Missing fields in predefined config.\n " );
406
419
return false ;
407
420
}
408
- if (validateConfig (config, shape))
409
- return true ;
410
- else {
411
- assert (0 && " config is invalid" );
421
+ bool validConfig =
422
+ validateConfig (config, shape, allowIndivisibleInnerBlock,
423
+ linalgx::isGenericPackedMatmulOp (
424
+ linalgOp, linalgx::PackingType::VNNI_MM2D));
425
+ if (!validConfig) {
426
+ LLVM_DEBUG (llvm::dbgs () << " Invalid predefined config.\n " );
412
427
return false ;
413
428
}
429
+ return true ;
414
430
}
415
431
416
432
// Analyze the workload and system description to generate the default config
@@ -494,13 +510,15 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
494
510
<< " M: " << M << " , N: " << N << " , K: " << K << " \n " );
495
511
496
512
// try to read the config from the attributes
497
- SmallVector<NamedAttribute> attrs (linalgOp-> getAttrs ());
498
- bool hasPredefinedConfig =
499
- readConfigFromAttrs (config, attrs, SmallVector< uint32_t >{M, N, K} );
513
+ bool hasValidPredefinedConfig = readAndValidateConfig (
514
+ config, linalgOp, SmallVector< uint32_t >{M, N, K},
515
+ allowIndivisibleInnerBlock );
500
516
501
517
// if there is a given config, skip the cost model
502
- if (!hasPredefinedConfig) {
503
- LLVM_DEBUG (llvm::dbgs () << " No predefined config\n " );
518
+ if (!hasValidPredefinedConfig) {
519
+ LLVM_DEBUG (
520
+ llvm::dbgs ()
521
+ << " No valid predefined config. Setting with default config.\n " );
504
522
// TODO: Could add a weight or priority for cost model
505
523
SmallVector<std::tuple<CostModelFn, std::string, double >>
506
524
costModelList = {
@@ -525,7 +543,10 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
525
543
if (!configCandidates.empty ())
526
544
config = configCandidates[0 ];
527
545
528
- assert (validateConfig (config, shape) && " config is invalid" );
546
+ assert (validateConfig (config, shape, allowIndivisibleInnerBlock,
547
+ linalgx::isGenericPackedMatmulOp (
548
+ root, linalgx::PackingType::VNNI_MM2D)) &&
549
+ " config is invalid" );
529
550
}
530
551
531
552
LLVM_DEBUG (llvm::dbgs ()
0 commit comments