@@ -37,15 +37,29 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
37
37
return ss;
38
38
}
39
39
40
- bool validateConfig (const MatmulConfig &cfg) {
40
+ bool validateConfig (const MatmulConfig &cfg, ArrayRef<uint32_t > shape,
41
+ bool allowIndivisibleInnerblock, bool isVNNIMM2D) {
41
42
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
42
43
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
43
44
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
44
45
cfg.innerMostKBlock <= 0 )
45
46
return false ;
46
47
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
47
- cfg.NBlock % cfg.innerMostNBlock != 0 ||
48
- cfg.KBlock % cfg.innerMostKBlock != 0 )
48
+ (shape[0 ] % cfg.innerMostMBlock != 0 && !allowIndivisibleInnerblock))
49
+ return false ;
50
+ if (cfg.NBlock % cfg.innerMostNBlock != 0 ||
51
+ ((shape[1 ] % cfg.innerMostNBlock != 0 ) && !allowIndivisibleInnerblock) ||
52
+ (shape[1 ] % cfg.NThreads != 0 && isVNNIMM2D &&
53
+ cfg.NBlock != cfg.innerMostNBlock ))
54
+ return false ;
55
+ // Require K % KBlock == 0 as brgemm dynamic bs is not supported now
56
+ if (cfg.KBlock % cfg.innerMostKBlock != 0 ||
57
+ ((shape[2 ] / cfg.KThreads % cfg.KBlock != 0 ||
58
+ shape[2 ] / cfg.KThreads % cfg.innerMostKBlock != 0 ) &&
59
+ !allowIndivisibleInnerblock))
60
+ return false ;
61
+ // KThreads will not shrink automatically
62
+ if (llvm::divideCeil (shape[2 ], cfg.KBlock ) < cfg.KThreads )
49
63
return false ;
50
64
return true ;
51
65
}
@@ -179,21 +193,22 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179
193
ArrayRef<uint32_t > shape,
180
194
const MatmulConfig &config,
181
195
CPUTargetDescriptionAnalysis &sysDesc) {
182
- assert (validateConfig (config) && " config is invalid" );
183
196
assert (shape.size () >= 3 && " shape.size() should >= 3" );
184
197
uint32_t M = shape[0 ], N = shape[1 ];
185
198
double cost = 0 ;
186
199
uint32_t MNumBlockPerThread =
187
200
llvm::divideCeil (M / config.innerMostMBlock , config.MThreads );
188
201
uint32_t MNumInnerBlockPerBlock =
189
202
llvm::divideCeil (config.MBlock , config.innerMostMBlock );
203
+ assert (MNumInnerBlockPerBlock > 0 && " Invalid MNumInnerBlockPerBlock." );
190
204
uint32_t MCost = MNumBlockPerThread % MNumInnerBlockPerBlock != 0 ||
191
205
(M / config.innerMostNBlock % config.MThreads != 0 &&
192
206
config.MBlock != config.innerMostMBlock );
193
207
uint32_t NNumBlockPerThread =
194
208
llvm::divideCeil (N / config.innerMostNBlock , config.NThreads );
195
209
uint32_t NNumInnerBlockPerBlock =
196
210
llvm::divideCeil (config.NBlock , config.innerMostNBlock );
211
+ assert (NNumInnerBlockPerBlock > 0 && " Invalid NNumInnerBlockPerBlock." );
197
212
uint32_t NCost = NNumBlockPerThread % NNumInnerBlockPerBlock != 0 ||
198
213
(N / config.innerMostNBlock % config.NThreads != 0 &&
199
214
config.NBlock != config.innerMostNBlock );
@@ -312,39 +327,28 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
312
327
KBlockCandidates = innerMostKBlockCandidates;
313
328
}
314
329
315
- // TODO: improve via multi threading or add more constraints to restrict the
316
- // candidate size
330
+ bool isVNNIMM2D =
331
+ linalgx::isGenericPackedMatmulOp (root, linalgx::PackingType::VNNI_MM2D);
332
+ // TODO: improve via multi threading or add more constraints to restrict
333
+ // the candidate size
317
334
for (uint32_t MThreads : MThreadsCandidates) {
318
335
for (uint32_t NThreads : NThreadsCandidates) {
319
336
for (uint32_t KThreads : KThreadsCandidates) {
320
337
if (!validateThreads ({MThreads, NThreads, KThreads}, sysDesc))
321
338
continue ;
322
339
for (uint32_t MBlock : MBlockCandidates) {
323
340
for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
324
- if (MBlock % innerMostMBlock != 0 ||
325
- (shape[0 ] % innerMostMBlock != 0 &&
326
- !allowIndivisibleInnerblock))
327
- continue ;
328
341
for (uint32_t NBlock : NBlockCandidates) {
329
342
for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
330
- if (NBlock % innerMostNBlock != 0 ||
331
- (shape[1 ] % innerMostNBlock != 0 &&
332
- !allowIndivisibleInnerblock))
333
- continue ;
334
343
for (uint32_t KBlock : KBlockCandidates) {
335
344
for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
336
- // Require K % KBlock == 0 as dynamic bs is not supported
337
- // now
338
- if (KBlock % innerMostKBlock != 0 ||
339
- ((shape[2 ] / KThreads % KBlock != 0 ||
340
- shape[2 ] / KThreads % innerMostKBlock != 0 ) &&
341
- !allowIndivisibleInnerblock))
342
- continue ;
343
345
MatmulConfig config{
344
346
MThreads, NThreads, KThreads,
345
347
MBlock, NBlock, KBlock,
346
348
innerMostMBlock, innerMostNBlock, innerMostKBlock};
347
- configs.push_back (config);
349
+ if (validateConfig (config, shape,
350
+ allowIndivisibleInnerblock, isVNNIMM2D))
351
+ configs.push_back (config);
348
352
}
349
353
}
350
354
}
@@ -393,12 +397,28 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393
397
cfgItemCnt++;
394
398
}
395
399
}
396
- if (validateConfig (config)) {
397
- return cfgItemCnt == 9 ;
398
- } else {
399
- LLVM_DEBUG (llvm::dbgs () << " The predefined config is invalid\n " );
400
+ return cfgItemCnt == 9 ;
401
+ }
402
+
403
+ bool readAndValidateConfig (MatmulConfig &config,
404
+ const linalg::LinalgOp &linalgOp,
405
+ ArrayRef<uint32_t > shape,
406
+ bool allowIndivisibleInnerBlock) {
407
+ SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
408
+ bool fullConfig = readConfigFromAttrs (config, attrs);
409
+ if (!fullConfig) {
410
+ LLVM_DEBUG (llvm::dbgs () << " Missing fields in predefined config.\n " );
400
411
return false ;
401
412
}
413
+ bool validConfig =
414
+ validateConfig (config, shape, allowIndivisibleInnerBlock,
415
+ linalgx::isGenericPackedMatmulOp (
416
+ linalgOp, linalgx::PackingType::VNNI_MM2D));
417
+ if (!validConfig) {
418
+ LLVM_DEBUG (llvm::dbgs () << " Invalid predefined config.\n " );
419
+ return false ;
420
+ }
421
+ return true ;
402
422
}
403
423
404
424
// Analyze the workload and system description to generate the default config
@@ -482,12 +502,15 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
482
502
<< " M: " << M << " , N: " << N << " , K: " << K << " \n " );
483
503
484
504
// try to read the config from the attributes
485
- SmallVector<NamedAttribute> attrs (linalgOp->getAttrs ());
486
- bool hasPredefinedConfig = readConfigFromAttrs (config, attrs);
505
+ bool hasValidPredefinedConfig = readAndValidateConfig (
506
+ config, linalgOp, SmallVector<uint32_t >{M, N, K},
507
+ allowIndivisibleInnerBlock);
487
508
488
509
// if there is a given config, skip the cost model
489
- if (!hasPredefinedConfig) {
490
- LLVM_DEBUG (llvm::dbgs () << " No predefined config\n " );
510
+ if (!hasValidPredefinedConfig) {
511
+ LLVM_DEBUG (
512
+ llvm::dbgs ()
513
+ << " No valid predefined config. Setting with default config.\n " );
491
514
// TODO: Could add a weight or priority for cost model
492
515
SmallVector<std::tuple<CostModelFn, std::string, double >>
493
516
costModelList = {
@@ -511,6 +534,11 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
511
534
}
512
535
if (!configCandidates.empty ())
513
536
config = configCandidates[0 ];
537
+
538
+ assert (validateConfig (config, shape, allowIndivisibleInnerBlock,
539
+ linalgx::isGenericPackedMatmulOp (
540
+ root, linalgx::PackingType::VNNI_MM2D)) &&
541
+ " config is invalid" );
514
542
}
515
543
516
544
LLVM_DEBUG (llvm::dbgs ()
@@ -520,7 +548,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520
548
hasConfig = true ;
521
549
}
522
550
523
- assert (validateConfig (config) && " config is invalid" );
524
551
return config;
525
552
}
526
553
} // namespace gc
0 commit comments