Skip to content

Commit 047399c

Browse files
authored
[mlir][sparse] cleanup of CodegenEnv reduction API (llvm#75243)
1 parent c77cdba commit 047399c

File tree

3 files changed

+36
-26
lines changed

3 files changed

+36
-26
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
115115
SmallVector<Value> params;
116116
if (isReduc()) {
117117
params.push_back(redVal);
118-
if (redValidLexInsert)
118+
if (isValidLexInsert())
119119
params.push_back(redValidLexInsert);
120120
} else {
121-
assert(!redValidLexInsert);
121+
assert(!isValidLexInsert());
122122
}
123123
if (isExpand())
124124
params.push_back(expCount);
@@ -128,8 +128,8 @@ std::optional<Operation *> CodegenEnv::genLoopBoundary(
128128
unsigned i = 0;
129129
if (isReduc()) {
130130
updateReduc(params[i++]);
131-
if (redValidLexInsert)
132-
setValidLexInsert(params[i++]);
131+
if (isValidLexInsert())
132+
updateValidLexInsert(params[i++]);
133133
}
134134
if (isExpand())
135135
updateExpandCount(params[i++]);
@@ -235,14 +235,14 @@ void CodegenEnv::endExpand() {
235235
//===----------------------------------------------------------------------===//
236236

237237
void CodegenEnv::startReduc(ExprId exp, Value val) {
238-
assert(!isReduc() && exp != detail::kInvalidId);
238+
assert(!isReduc() && exp != detail::kInvalidId && val);
239239
redExp = exp;
240240
redVal = val;
241241
latticeMerger.setExprValue(exp, val);
242242
}
243243

244244
void CodegenEnv::updateReduc(Value val) {
245-
assert(isReduc());
245+
assert(isReduc() && val);
246246
redVal = val;
247247
latticeMerger.clearExprValue(redExp);
248248
latticeMerger.setExprValue(redExp, val);
@@ -257,13 +257,18 @@ Value CodegenEnv::endReduc() {
257257
return val;
258258
}
259259

260-
void CodegenEnv::setValidLexInsert(Value val) {
261-
assert(isReduc() && val);
260+
void CodegenEnv::startValidLexInsert(Value val) {
261+
assert(!isValidLexInsert() && isReduc() && val);
262+
redValidLexInsert = val;
263+
}
264+
265+
void CodegenEnv::updateValidLexInsert(Value val) {
266+
assert(redValidLexInsert && isReduc() && val);
262267
redValidLexInsert = val;
263268
}
264269

265-
void CodegenEnv::clearValidLexInsert() {
266-
assert(!isReduc());
270+
void CodegenEnv::endValidLexInsert() {
271+
assert(isValidLexInsert() && !isReduc());
267272
redValidLexInsert = Value();
268273
}
269274

@@ -272,7 +277,7 @@ void CodegenEnv::startCustomReduc(ExprId exp) {
272277
redCustom = exp;
273278
}
274279

275-
Value CodegenEnv::getCustomRedId() {
280+
Value CodegenEnv::getCustomRedId() const {
276281
assert(isCustomReduc());
277282
return dyn_cast<sparse_tensor::ReduceOp>(exp(redCustom).op).getIdentity();
278283
}

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,16 @@ class CodegenEnv {
150150
void updateReduc(Value val);
151151
Value getReduc() const { return redVal; }
152152
Value endReduc();
153-
void setValidLexInsert(Value val);
154-
void clearValidLexInsert();
153+
154+
void startValidLexInsert(Value val);
155+
bool isValidLexInsert() const { return redValidLexInsert != nullptr; }
156+
void updateValidLexInsert(Value val);
155157
Value getValidLexInsert() const { return redValidLexInsert; }
158+
void endValidLexInsert();
156159

157160
void startCustomReduc(ExprId exp);
158161
bool isCustomReduc() const { return redCustom != detail::kInvalidId; }
159-
Value getCustomRedId();
162+
Value getCustomRedId() const;
160163
void endCustomReduc();
161164

162165
private:

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -415,9 +415,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
415415
SmallVector<Value> ivs = llvm::to_vector(llvm::drop_end(
416416
env.emitter().getLoopIVsRange(), env.getCurrentDepth() - numLoops));
417417
Value chain = env.getInsertionChain();
418-
if (!env.getValidLexInsert()) {
419-
env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
420-
} else {
418+
if (env.isValidLexInsert()) {
421419
// Generates runtime check for a valid lex during reduction,
422420
// to avoid inserting the identity value for empty reductions.
423421
// if (validLexInsert) then
@@ -438,6 +436,9 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
438436
// Value assignment.
439437
builder.setInsertionPointAfter(ifValidLexInsert);
440438
env.updateInsertionChain(ifValidLexInsert.getResult(0));
439+
} else {
440+
// Generates regular insertion chain.
441+
env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
441442
}
442443
return;
443444
}
@@ -688,12 +689,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
688689
env.startReduc(exp, genTensorLoad(env, builder, exp));
689690
}
690691
if (env.hasSparseOutput())
691-
env.setValidLexInsert(constantI1(builder, env.op().getLoc(), false));
692+
env.startValidLexInsert(
693+
constantI1(builder, env.op().getLoc(), false));
692694
} else {
693695
if (!env.isCustomReduc() || env.isReduc())
694696
genTensorStore(env, builder, exp, env.endReduc());
695697
if (env.hasSparseOutput())
696-
env.clearValidLexInsert();
698+
env.endValidLexInsert();
697699
}
698700
} else {
699701
// Start or end loop invariant hoisting of a tensor load.
@@ -846,9 +848,9 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
846848
if (env.isReduc()) {
847849
yields.push_back(env.getReduc());
848850
env.updateReduc(ifOp.getResult(y++));
849-
if (env.getValidLexInsert()) {
851+
if (env.isValidLexInsert()) {
850852
yields.push_back(env.getValidLexInsert());
851-
env.setValidLexInsert(ifOp.getResult(y++));
853+
env.updateValidLexInsert(ifOp.getResult(y++));
852854
}
853855
}
854856
if (env.isExpand()) {
@@ -904,7 +906,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
904906
});
905907
if (env.isReduc()) {
906908
types.push_back(env.getReduc().getType());
907-
if (env.getValidLexInsert())
909+
if (env.isValidLexInsert())
908910
types.push_back(env.getValidLexInsert().getType());
909911
}
910912
if (env.isExpand())
@@ -924,10 +926,10 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
924926
if (env.isReduc()) {
925927
operands.push_back(env.getReduc());
926928
env.updateReduc(redInput);
927-
if (env.getValidLexInsert()) {
929+
if (env.isValidLexInsert()) {
928930
// Any overlapping indices during a reduction creates a valid lex insert.
929931
operands.push_back(constantI1(builder, env.op().getLoc(), true));
930-
env.setValidLexInsert(validIns);
932+
env.updateValidLexInsert(validIns);
931933
}
932934
}
933935
if (env.isExpand()) {
@@ -1174,8 +1176,8 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
11741176
// Either a for-loop or a while-loop that iterates over a slice.
11751177
if (isSingleCond) {
11761178
// Any iteration creates a valid lex insert.
1177-
if (env.isReduc() && env.getValidLexInsert())
1178-
env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
1179+
if (env.isReduc() && env.isValidLexInsert())
1180+
env.updateValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
11791181
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
11801182
// End a while-loop.
11811183
finalizeWhileOp(env, rewriter, needsUniv);

0 commit comments

Comments
 (0)