Skip to content

Commit ceabcc2

Browse files
committed
Add config validity check for KThreads
1 parent d794dc7 commit ceabcc2

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
3737
return ss;
3838
}
3939

40-
bool validateConfig(const MatmulConfig &cfg) {
40+
bool validateConfig(const MatmulConfig &cfg, ArrayRef<uint32_t> shape = {}) {
4141
if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
4242
cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
4343
cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
@@ -47,6 +47,12 @@ bool validateConfig(const MatmulConfig &cfg) {
4747
cfg.NBlock % cfg.innerMostNBlock != 0 ||
4848
cfg.KBlock % cfg.innerMostKBlock != 0)
4949
return false;
50+
if (!shape.empty()) {
51+
// KThreads will not shrink automatically
52+
// K is shape[2]
53+
if (llvm::divideCeil(shape[2], cfg.KBlock) < cfg.KThreads)
54+
return false;
55+
}
5056
return true;
5157
}
5258

@@ -179,7 +185,7 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
179185
ArrayRef<uint32_t> shape,
180186
const MatmulConfig &config,
181187
CPUTargetDescriptionAnalysis &sysDesc) {
182-
assert(validateConfig(config) && "config is invalid");
188+
assert(validateConfig(config, shape) && "config is invalid");
183189
assert(shape.size() >= 3 && "shape.size() should >= 3");
184190
uint32_t M = shape[0], N = shape[1];
185191
double cost = 0;
@@ -361,7 +367,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
361367
}
362368

363369
// read the config from the attributes for tuning
364-
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
370+
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs,
371+
ArrayRef<uint32_t> shape) {
365372
size_t cfgItemCnt = 0;
366373
for (const auto &attr : attrs) {
367374
if (attr.getName() == "KBlock") {
@@ -393,7 +400,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
393400
cfgItemCnt++;
394401
}
395402
}
396-
if (validateConfig(config)) {
403+
if (validateConfig(config, shape)) {
397404
return cfgItemCnt == 9;
398405
} else {
399406
LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
@@ -483,7 +490,8 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
483490

484491
// try to read the config from the attributes
485492
SmallVector<NamedAttribute> attrs(linalgOp->getAttrs());
486-
bool hasPredefinedConfig = readConfigFromAttrs(config, attrs);
493+
bool hasPredefinedConfig =
494+
readConfigFromAttrs(config, attrs, SmallVector<uint32_t>{M, N, K});
487495

488496
// if there is a given config, skip the cost model
489497
if (!hasPredefinedConfig) {
@@ -520,7 +528,6 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
520528
hasConfig = true;
521529
}
522530

523-
assert(validateConfig(config) && "config is invalid");
524531
return config;
525532
}
526533
} // namespace gc

0 commit comments

Comments
 (0)