Skip to content

Commit 83dcced

Browse files
committed
Add config validity check
1 parent 37a70e6 commit 83dcced

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ struct MatmulConfig {
3333
uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
3434
};
3535

36+
bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape = {});
37+
3638
enum DimType { Batch, M, N, K };
3739

3840
// Extract the index of the given DimType in the DimType list

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737
return ss;
3838
}
3939

40-
bool validateConfig(const MatmulConfig &cfg) {
40+
bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape) {
4141
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4242
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4343
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
4747
cfg.NBlock % cfg.innerMostNBlock != 0 ||
4848
cfg.KBlock % cfg.innerMostKBlock != 0)
4949
return false;
50+
if (!shape.empty()) {
51+
// KThreads will not shrink automatically
52+
// K is shape[2]
53+
if (llvm::divideCeil(shape[2], cfg.KBlock) < cfg.KThreads)
54+
return false;
55+
}
5056
return true;
5157
}
5258

@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179185
ArrayRef<uint32_t> shape,
180186
const MatmulConfig &config,
181187
CPUTargetDescriptionAnalysis &sysDesc) {
182-
assert(validateConfig(config) && "config is invalid");
188+
assert(validateConfig(config, shape) && "config is invalid");
183189
assert(shape.size() >= 3 && "shape.size() should >= 3");
184190
uint32_t M = shape[0], N = shape[1];
185191
double cost = 0;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361367
}
362368

363369
// read the config from the attributes for tuning
364-
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370+
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371+
ArrayRef<uint32_t> shape) {
365372
size_t cfgItemCnt = 0;
366373
for (const auto &attr : attrs) {
367374
if (attr.getName() == "KBlock") {
@@ -393,10 +400,15 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393400
cfgItemCnt++;
394401
}
395402
}
396-
if (validateConfig(config)) {
397-
return cfgItemCnt == 9;
398-
} else {
399-
LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
403+
if (cfgItemCnt != 9) {
404+
LLVM_DEBUG(llvm::dbgs() << "The predefined matmul config is incomplete. "
405+
"Default matmul config will be set.\n");
406+
return false;
407+
}
408+
if (validateConfig(config, shape))
409+
return true;
410+
else {
411+
assert(0 && "config is invalid");
400412
return false;
401413
}
402414
}
@@ -483,7 +495,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483495

484496
// try to read the config from the attributes
485497
SmallVector<NamedAttribute> attrs(linalgOp->getAttrs());
486-
bool hasPredefinedConfig = readConfigFromAttrs(config, attrs);
498+
bool hasPredefinedConfig =
499+
readConfigFromAttrs(config, attrs, SmallVector<uint32_t>{M, N, K});
487500

488501
// if there is a given config, skip the cost model
489502
if (!hasPredefinedConfig) {
@@ -511,6 +524,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
511524
}
512525
if (!configCandidates.empty())
513526
config = configCandidates[0];
527+
528+
assert(validateConfig(config, shape) && "config is invalid");
514529
}
515530

516531
LLVM_DEBUG(llvm::dbgs()
@@ -520,7 +535,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520535
hasConfig = true;
521536
}
522537

523-
assert(validateConfig(config) && "config is invalid");
524538
return config;
525539
}
526540
} // namespace gc

0 commit comments

Comments
 (0)