Skip to content

Commit 27a7da6

Browse files
authored
Add matmul config validity check (#376)
* Add matmul config validity check
* Align constraints
1 parent 37a70e6 commit 27a7da6

File tree

2 files changed

+61
-31
lines changed

2 files changed

+61
-31
lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,9 @@ struct MatmulConfig {
3333
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
3434
};
3535

36+
/// Returns true when `cfg` satisfies the matmul config constraints for the
/// problem sizes in `shape` (read as shape[0] = M, shape[1] = N, shape[2] = K
/// by the definition).
/// `allowIndivisibleInnerblock` relaxes the shape-divisibility constraints and
/// `isVNNIMM2D` enables an extra N-dimension constraint; see the definition
/// for the exact rules.
bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape,
37+
bool allowIndivisibleInnerblock, bool isVNNIMM2D);
38+
3639
// Kinds of dimension distinguished by this analysis: a batch dimension, or
// one of M / N / K (presumably the matmul M, N and K dimensions — confirm).
enum DimType { Batch, M, N, K };
3740

3841
// Extract the index of the given DimType in the DimType list

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 58 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -37,15 +37,29 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737
return ss;
3838
}
3939

40-
bool validateConfig(const MatmulConfig &cfg) {
40+
bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape,
41+
bool allowIndivisibleInnerblock, bool isVNNIMM2D) {
4142
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4243
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4344
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
4445
cfg.innerMostKBlock <= 0)
4546
return false;
4647
if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
47-
cfg.NBlock % cfg.innerMostNBlock != 0 ||
48-
cfg.KBlock % cfg.innerMostKBlock != 0)
48+
(shape[0] % cfg.innerMostMBlock != 0 && !allowIndivisibleInnerblock))
49+
return false;
50+
if (cfg.NBlock % cfg.innerMostNBlock != 0 ||
51+
((shape[1] % cfg.innerMostNBlock != 0) && !allowIndivisibleInnerblock) ||
52+
(shape[1] % cfg.NThreads != 0 && isVNNIMM2D &&
53+
cfg.NBlock != cfg.innerMostNBlock))
54+
return false;
55+
// Require K % KBlock == 0 as brgemm dynamic bs is not supported now
56+
if (cfg.KBlock % cfg.innerMostKBlock != 0 ||
57+
((shape[2] / cfg.KThreads % cfg.KBlock != 0 ||
58+
shape[2] / cfg.KThreads % cfg.innerMostKBlock != 0) &&
59+
!allowIndivisibleInnerblock))
60+
return false;
61+
// KThreads will not shrink automatically
62+
if (llvm::divideCeil(shape[2], cfg.KBlock) < cfg.KThreads)
4963
return false;
5064
return true;
5165
}
@@ -179,21 +193,22 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179193
ArrayRef<uint32_t> shape,
180194
const MatmulConfig &config,
181195
CPUTargetDescriptionAnalysis &sysDesc) {
182-
assert(validateConfig(config) && "config is invalid");
183196
assert(shape.size() >= 3 && "shape.size() should >= 3");
184197
uint32_t M = shape[0], N = shape[1];
185198
double cost = 0;
186199
uint32_t MNumBlockPerThread =
187200
llvm::divideCeil(M / config.innerMostMBlock, config.MThreads);
188201
uint32_t MNumInnerBlockPerBlock =
189202
llvm::divideCeil(config.MBlock, config.innerMostMBlock);
203+
assert(MNumInnerBlockPerBlock > 0 && "Invalid MNumInnerBlockPerBlock.");
190204
uint32_t MCost = MNumBlockPerThread % MNumInnerBlockPerBlock != 0 ||
191205
(M / config.innerMostNBlock % config.MThreads != 0 &&
192206
config.MBlock != config.innerMostMBlock);
193207
uint32_t NNumBlockPerThread =
194208
llvm::divideCeil(N / config.innerMostNBlock, config.NThreads);
195209
uint32_t NNumInnerBlockPerBlock =
196210
llvm::divideCeil(config.NBlock, config.innerMostNBlock);
211+
assert(NNumInnerBlockPerBlock > 0 && "Invalid NNumInnerBlockPerBlock.");
197212
uint32_t NCost = NNumBlockPerThread % NNumInnerBlockPerBlock != 0 ||
198213
(N / config.innerMostNBlock % config.NThreads != 0 &&
199214
config.NBlock != config.innerMostNBlock);
@@ -312,39 +327,28 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
312327
KBlockCandidates = innerMostKBlockCandidates;
313328
}
314329

315-
// TODO: improve via multi threading or add more constraints to restrict the
316-
// candidate size
330+
bool isVNNIMM2D =
331+
linalgx::isGenericPackedMatmulOp(root, linalgx::PackingType::VNNI_MM2D);
332+
// TODO: improve via multi threading or add more constraints to restrict
333+
// the candidate size
317334
for (uint32_t MThreads : MThreadsCandidates) {
318335
for (uint32_t NThreads : NThreadsCandidates) {
319336
for (uint32_t KThreads : KThreadsCandidates) {
320337
if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc))
321338
continue;
322339
for (uint32_t MBlock : MBlockCandidates) {
323340
for (uint32_t innerMostMBlock : innerMostMBlockCandidates) {
324-
if (MBlock % innerMostMBlock != 0 ||
325-
(shape[0] % innerMostMBlock != 0 &&
326-
!allowIndivisibleInnerblock))
327-
continue;
328341
for (uint32_t NBlock : NBlockCandidates) {
329342
for (uint32_t innerMostNBlock : innerMostNBlockCandidates) {
330-
if (NBlock % innerMostNBlock != 0 ||
331-
(shape[1] % innerMostNBlock != 0 &&
332-
!allowIndivisibleInnerblock))
333-
continue;
334343
for (uint32_t KBlock : KBlockCandidates) {
335344
for (uint32_t innerMostKBlock : innerMostKBlockCandidates) {
336-
// Require K % KBlock == 0 as dynamic bs is not supported
337-
// now
338-
if (KBlock % innerMostKBlock != 0 ||
339-
((shape[2] / KThreads % KBlock != 0 ||
340-
shape[2] / KThreads % innerMostKBlock != 0) &&
341-
!allowIndivisibleInnerblock))
342-
continue;
343345
MatmulConfig config{
344346
MThreads, NThreads, KThreads,
345347
MBlock, NBlock, KBlock,
346348
innerMostMBlock, innerMostNBlock, innerMostKBlock};
347-
configs.push_back(config);
349+
if (validateConfig(config, shape,
350+
allowIndivisibleInnerblock, isVNNIMM2D))
351+
configs.push_back(config);
348352
}
349353
}
350354
}
@@ -393,12 +397,28 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393397
cfgItemCnt++;
394398
}
395399
}
396-
if (validateConfig(config)) {
397-
return cfgItemCnt == 9;
398-
} else {
399-
LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
400+
return cfgItemCnt == 9;
401+
}
402+
403+
/// Reads a predefined matmul config from the attributes of `linalgOp` and
/// validates it against `shape`.
///
/// \param config    Receives the predefined config. It is written only when
///                  the function returns true; on failure it is left
///                  untouched so the caller's existing values survive.
/// \param linalgOp  Operation whose attributes carry the predefined config.
/// \param shape     Problem sizes forwarded to validateConfig.
/// \param allowIndivisibleInnerBlock  Forwarded to validateConfig.
/// \returns true if all config fields were present and the resulting config
///          passed validateConfig; false otherwise.
bool readAndValidateConfig(MatmulConfig &config,
                           const linalg::LinalgOp &linalgOp,
                           ArrayRef<uint32_t> shape,
                           bool allowIndivisibleInnerBlock) {
  // Parse into a scratch copy: a partial or invalid predefined config must
  // not clobber whatever the caller already stored in `config`.
  MatmulConfig candidate = config;
  SmallVector<NamedAttribute> attrs(linalgOp->getAttrs());
  if (!readConfigFromAttrs(candidate, attrs)) {
    LLVM_DEBUG(llvm::dbgs() << "Missing fields in predefined config.\n");
    return false;
  }

  const bool isVNNIMM2D = linalgx::isGenericPackedMatmulOp(
      linalgOp, linalgx::PackingType::VNNI_MM2D);
  if (!validateConfig(candidate, shape, allowIndivisibleInnerBlock,
                      isVNNIMM2D)) {
    LLVM_DEBUG(llvm::dbgs() << "Invalid predefined config.\n");
    return false;
  }

  config = candidate;
  return true;
}
403423

404424
// Analyze the workload and system description to generate the default config
@@ -482,12 +502,15 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
482502
<< "M: " << M << ", N: " << N << ", K: " << K << "\n");
483503

484504
// try to read the config from the attributes
485-
SmallVector<NamedAttribute> attrs(linalgOp->getAttrs());
486-
bool hasPredefinedConfig = readConfigFromAttrs(config, attrs);
505+
bool hasValidPredefinedConfig = readAndValidateConfig(
506+
config, linalgOp, SmallVector<uint32_t>{M, N, K},
507+
allowIndivisibleInnerBlock);
487508

488509
// if there is a given config, skip the cost model
489-
if (!hasPredefinedConfig) {
490-
LLVM_DEBUG(llvm::dbgs() << "No predefined config\n");
510+
if (!hasValidPredefinedConfig) {
511+
LLVM_DEBUG(
512+
llvm::dbgs()
513+
<< "No valid predefined config. Setting with default config.\n");
491514
// TODO: Could add a weight or priority for cost model
492515
SmallVector<std::tuple<CostModelFn, std::string, double>>
493516
costModelList = {
@@ -511,6 +534,11 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
511534
}
512535
if (!configCandidates.empty())
513536
config = configCandidates[0];
537+
538+
assert(validateConfig(config, shape, allowIndivisibleInnerBlock,
539+
linalgx::isGenericPackedMatmulOp(
540+
root, linalgx::PackingType::VNNI_MM2D)) &&
541+
"config is invalid");
514542
}
515543

516544
LLVM_DEBUG(llvm::dbgs()
@@ -520,7 +548,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520548
hasConfig = true;
521549
}
522550

523-
assert(validateConfig(config) && "config is invalid");
524551
return config;
525552
}
526553
} // namespace gc

0 commit comments

Comments
 (0)