@@ -415,9 +415,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
415
415
SmallVector<Value> ivs = llvm::to_vector (llvm::drop_end (
416
416
env.emitter ().getLoopIVsRange (), env.getCurrentDepth () - numLoops));
417
417
Value chain = env.getInsertionChain ();
418
- if (!env.getValidLexInsert ()) {
419
- env.updateInsertionChain (builder.create <InsertOp>(loc, rhs, chain, ivs));
420
- } else {
418
+ if (env.isValidLexInsert ()) {
421
419
// Generates runtime check for a valid lex during reduction,
422
420
// to avoid inserting the identity value for empty reductions.
423
421
// if (validLexInsert) then
@@ -438,6 +436,9 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
438
436
// Value assignment.
439
437
builder.setInsertionPointAfter (ifValidLexInsert);
440
438
env.updateInsertionChain (ifValidLexInsert.getResult (0 ));
439
+ } else {
440
+ // Generates regular insertion chain.
441
+ env.updateInsertionChain (builder.create <InsertOp>(loc, rhs, chain, ivs));
441
442
}
442
443
return ;
443
444
}
@@ -688,12 +689,13 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
688
689
env.startReduc (exp, genTensorLoad (env, builder, exp));
689
690
}
690
691
if (env.hasSparseOutput ())
691
- env.setValidLexInsert (constantI1 (builder, env.op ().getLoc (), false ));
692
+ env.startValidLexInsert (
693
+ constantI1 (builder, env.op ().getLoc (), false ));
692
694
} else {
693
695
if (!env.isCustomReduc () || env.isReduc ())
694
696
genTensorStore (env, builder, exp, env.endReduc ());
695
697
if (env.hasSparseOutput ())
696
- env.clearValidLexInsert ();
698
+ env.endValidLexInsert ();
697
699
}
698
700
} else {
699
701
// Start or end loop invariant hoisting of a tensor load.
@@ -846,9 +848,9 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
846
848
if (env.isReduc ()) {
847
849
yields.push_back (env.getReduc ());
848
850
env.updateReduc (ifOp.getResult (y++));
849
- if (env.getValidLexInsert ()) {
851
+ if (env.isValidLexInsert ()) {
850
852
yields.push_back (env.getValidLexInsert ());
851
- env.setValidLexInsert (ifOp.getResult (y++));
853
+ env.updateValidLexInsert (ifOp.getResult (y++));
852
854
}
853
855
}
854
856
if (env.isExpand ()) {
@@ -904,7 +906,7 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
904
906
});
905
907
if (env.isReduc ()) {
906
908
types.push_back (env.getReduc ().getType ());
907
- if (env.getValidLexInsert ())
909
+ if (env.isValidLexInsert ())
908
910
types.push_back (env.getValidLexInsert ().getType ());
909
911
}
910
912
if (env.isExpand ())
@@ -924,10 +926,10 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
924
926
if (env.isReduc ()) {
925
927
operands.push_back (env.getReduc ());
926
928
env.updateReduc (redInput);
927
- if (env.getValidLexInsert ()) {
929
+ if (env.isValidLexInsert ()) {
928
930
// Any overlapping indices during a reduction creates a valid lex insert.
929
931
operands.push_back (constantI1 (builder, env.op ().getLoc (), true ));
930
- env.setValidLexInsert (validIns);
932
+ env.updateValidLexInsert (validIns);
931
933
}
932
934
}
933
935
if (env.isExpand ()) {
@@ -1174,8 +1176,8 @@ static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
1174
1176
// Either a for-loop or a while-loop that iterates over a slice.
1175
1177
if (isSingleCond) {
1176
1178
// Any iteration creates a valid lex insert.
1177
- if (env.isReduc () && env.getValidLexInsert ())
1178
- env.setValidLexInsert (constantI1 (rewriter, env.op ().getLoc (), true ));
1179
+ if (env.isReduc () && env.isValidLexInsert ())
1180
+ env.updateValidLexInsert (constantI1 (rewriter, env.op ().getLoc (), true ));
1179
1181
} else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
1180
1182
// End a while-loop.
1181
1183
finalizeWhileOp (env, rewriter, needsUniv);
0 commit comments