Commit 98ce2de

[mlir][sparse] cleanup ldx/idx/depth/at usage (#74654)
This adopts a consistent use of `at` for everything that refers to the current loop nesting, and cleans up some redundant legacy code left over from when the sparsifier still used topSort internally.
1 parent 48ca868 commit 98ce2de


mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 69 additions & 70 deletions
@@ -44,23 +44,23 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//

-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
-                              bool &isAtLoop) {
+/// Returns true iff affine expression is invariant. Sets the
+/// parameter `isAtLoop` when expression just became invariant.
+static bool isInvariantAffine(AffineExpr a, LoopId at, bool &isAtLoop) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
-    if (i == ldx) {
+    if (i + 1 == at) {
       isAtLoop = true;
-      return true; // invariant at given loop
+      return true; // becomes invariant at current loop
     }
-    return i < loopDepth; // invariant when already generated
+    return i < at; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = cast<AffineBinaryOpExpr>(a);
-    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
-           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
+    return isInvariantAffine(binOp.getLHS(), at, isAtLoop) &&
+           isInvariantAffine(binOp.getRHS(), at, isAtLoop);
   }
   default: {
     assert(isa<AffineConstantExpr>(a));
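The new convention is easiest to see in isolation: `at` is the current loop depth, a dimension d_i is invariant once i < at, and `isAtLoop` fires exactly when some dimension just became invariant at the enclosing loop, i.e. i + 1 == at (so the old `ldx` is always at - 1, and the separate `loopDepth` parameter is redundant). Below is a minimal standalone sketch under that reading; `MiniAffine` and the sample values are hypothetical stand-ins for MLIR's AffineExpr, not code from this commit.

#include <cassert>
#include <cstdint>
#include <vector>

using LoopId = uint64_t;

// Simplified stand-in for an affine expression: a sum of loop dims d_i.
struct MiniAffine {
  std::vector<LoopId> dims;
};

// Invariant iff every dimension is already generated (i < at); sets
// `isAtLoop` when a dimension becomes invariant exactly at the current
// loop (i + 1 == at), mirroring the new single-parameter convention.
static bool isInvariantAffine(const MiniAffine &a, LoopId at, bool &isAtLoop) {
  for (LoopId i : a.dims) {
    if (i + 1 == at)
      isAtLoop = true; // becomes invariant at current loop
    else if (i >= at)
      return false;    // still in play
  }
  return true;
}

int main() {
  MiniAffine e{{0, 1}}; // d0 + d1
  bool isAtLoop = false;
  assert(isInvariantAffine(e, /*at=*/2, isAtLoop) && isAtLoop); // d1 at loop
  isAtLoop = false;
  assert(!isInvariantAffine(e, /*at=*/1, isAtLoop)); // d1 still in play
  return 0;
}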
@@ -126,23 +126,23 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
   if (coefficient <= 0)
     return false;

-  const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
-  if (!isUndefLT(merger.getLvlType(tensor, ldx)))
+  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
+  if (!isUndefLT(merger.getLvlType(tensor, idx)))
     return false; // used more than once, e.g., A[i][i]

   // TODO: Generalizes the following two cases. A[i] (with trivial index
   // expression) can be treated as a special affine index expression. We do
   // not necessarily need to differentiate them.
   if (!isSubExp) {
     assert(coefficient == 1);
-    merger.setLevelAndType(tensor, ldx, lvl, lt);
+    merger.setLevelAndType(tensor, idx, lvl, lt);
   }

   if (isSubExp) {
     // The current loops appears in more than one affine expressions on the
     // same tensor. We can not handle this case. e.g., A[i+j][i+k], `i` is
     // used twice.
-    if (merger.hasDependentLvl(ldx, tensor)) {
+    if (merger.hasDependentLvl(idx, tensor)) {
       // TODO: This can be supported by coiterate slices if the loop idx is
       // appeared on affine index for different tensor, or take slice on
       // multiple dimensions when it is on the same tensor.
@@ -154,7 +154,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
       // else increase min(d0_1, d0_2).
       return false;
     }
-    merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
+    merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
   }
   return true;
 }
@@ -613,9 +613,9 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (kind == TensorExp::Kind::kReduce)
     env.startCustomReduc(e); // enter custom

-  Value v0, v1;
   // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
   // based on the type of the other operand.
+  Value v0, v1;
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
     v1 = genExp(env, rewriter, exp.children.e1);
@@ -655,21 +655,21 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {

 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                          LoopId ldx, bool atStart) {
+                          LoopId at, bool atStart) {
   if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
     return;
   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
-    bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
     linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     const auto map = op.getMatchingIndexingMap(&t);
     const auto stt = getSparseTensorType(t.get());
     const Level lvlRank = stt.getLvlRank();
     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+    bool isAtLoop = at == 0; // for scalar tensors
     for (Level l = 0; l < lvlRank; l++) {
       const AffineExpr a = map.getResult(l);
-      if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+      if (!isInvariantAffine(a, at, /*out*/ isAtLoop))
         return; // still in play
     }
     // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -705,8 +705,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
     env.startCustomReduc(exp); // enter custom
     const ExprId e0 = env.exp(exp).children.e0;
     const ExprId e1 = env.exp(exp).children.e1;
-    genInvariants(env, builder, e0, ldx, atStart);
-    genInvariants(env, builder, e1, ldx, atStart);
+    genInvariants(env, builder, e0, at, atStart);
+    genInvariants(env, builder, e1, at, atStart);
     if (env.exp(exp).kind == TensorExp::Kind::kReduce)
       env.endCustomReduc(); // exit custom
   }
@@ -782,29 +782,28 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {

 /// Whether or not the current loop being generated should be parallized (if
 /// possible) according to the configuration.
-static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+static bool shouldTryParallize(CodegenEnv &env, LoopId at,
                                ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
-    // Queries the LT based on the tensor id and loop idx, as requested by
-    // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
+  bool isSparse = llvm::any_of(tidLvls, [at, &env](TensorLevel tidLvl) {
+    // Queries the LT based on the tensor and loop id, as requested by
+    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
     // should be consistent with the LT indexed by <TensorId, Level>.
-    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
+    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, at);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-  return isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, /*isOuter=*/at == 0, isSparse);
 }

 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
 /// can either be a for loop or while loop depending on whether there is at most
 /// one sparse level in the list.
 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
-                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 ArrayRef<TensorLevel> tidLvls,
                                  bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    // Construct the while-loop with a parameter for each
-    // index.
+    // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
         builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
         /*genDedup=*/true, needsUniv);
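When more than one sparse level participates, the emitted co-iteration takes the while-loop form with one coordinate cursor per tensor. A toy standalone sketch of that shape (made-up coordinate lists, not the loop emitter's actual code): each step processes the minimum coordinate, which doubles as the universal index, and advances exactly the cursors sitting at it.

#include <cstdio>
#include <vector>

int main() {
  // Sorted coordinate lists of two sparse levels (hypothetical data).
  std::vector<unsigned> a = {0, 3, 5};
  std::vector<unsigned> b = {1, 3, 7};
  size_t i = 0, j = 0;
  while (i < a.size() && j < b.size()) {
    unsigned ca = a[i], cb = b[j];
    unsigned c = ca < cb ? ca : cb; // universal index: min of coordinates
    if (ca == c && cb == c)
      std::printf("both tensors at %u\n", c); // intersection case
    else
      std::printf("one tensor at %u\n", c);   // union case
    i += (ca == c); // advance only the cursors that matched
    j += (cb == c);
  }
  return 0;
}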
@@ -817,12 +816,12 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId at,
                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
-  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
-  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
+  bool tryParallel = shouldTryParallize(env, at, tidLvls);
+  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
 }

 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                             bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
@@ -862,7 +861,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 }

 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId at,
                        LatPointId p) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
@@ -880,13 +879,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
       auto stt = getSparseTensorType(env.op().getInputs()[tid]);
       lt = stt.getLvlType(*lvl);
     }
-    assert(ldx == env.merger().loop(b));
+    assert(at == env.merger().loop(b));
     Value clause;
     if (isCompressedLT(lt) || isSingletonLT(lt) ||
         isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
       assert(lvl.has_value());
       const Value crd = env.emitter().getCoords()[tid][*lvl];
-      const Value lvar = env.getLoopVar(ldx);
+      const Value lvar = env.getLoopVar(at);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              crd, lvar);
     } else {
@@ -943,12 +942,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId idx, LoopId ldx, LatSetId lts) {
-  assert(!env.getLoopVar(idx));
+                         LoopId at, LatSetId lts) {
+  assert(!env.getLoopVar(at));
   // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
+  genInvariants(env, builder, exp, at, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further intitialization at this loop sequence level.
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
@@ -957,7 +956,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
                                            LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == idx);
+    assert(env.merger().loop(b) == at);
     if (isDenseLT(lt) || isUndefLT(lt)) {
       if (tid == env.merger().getSynTensorID()) {
         // Needs loop emitter to set up loop bounds for synthetic tensor too if
@@ -988,6 +987,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   return false;
 }

+// Generates dense affine address for encoding.
 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, TensorId tid,
                                              Level startLvl) {
@@ -1013,30 +1013,30 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   }
 }

+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on any thing.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
 static void genInitConstantDenseAddress(CodegenEnv &env,
                                         RewriterBase &rewriter) {
-  // We can generate address for constant affine expression before any loops
-  // starting from the first level as they do not depend on any thing.
-  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-  // levels can be determined before loops.
   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }

 /// Return true if the lattices bit can be iterated by a for loop.
 static bool translateBitsToTidLvlPairs(
-    CodegenEnv &env, LatPointId li, LoopId ldx,
+    CodegenEnv &env, LatPointId li, LoopId at,
     SmallVectorImpl<TensorLevel> &tidLvls,
     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
-  const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
+  const std::optional<Level> outLvl = env.merger().getLvl(outTid, at);

   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                LevelType lt, bool isIdxReduc) {
+  env.merger().foreachTensorLoopId(li, [&, at](TensorLoopId b, TensorId tid,
+                                               std::optional<Level> lvl,
+                                               LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
         tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
@@ -1089,11 +1089,11 @@ static bool translateBitsToTidLvlPairs(
       if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
         continue;

-      // Constant affine expression are handled in genLoop
+      // Constant affine expression are handled in genLoop.
       if (!isa<AffineConstantExpr>(exp)) {
         bool isAtLoop = false;
-        if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
-            isAtLoop) {
+        assert(at == env.getLoopDepth());
+        if (isInvariantAffine(exp, at + 1, /*out*/ isAtLoop) && isAtLoop) {
           // If the compound affine is invariant and we are right at the
           // level. We need to generate the address according to the
           // affine expression. This is also the best place we can do it
@@ -1105,7 +1105,7 @@ static bool translateBitsToTidLvlPairs(
     }
   });

-  if (isDenseLT(env.lt(outTid, ldx))) {
+  if (isDenseLT(env.lt(outTid, at))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
@@ -1131,9 +1131,9 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               LatPointId li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on
   SmallVector<TensorLevel> tidLvls;
+
   // The set of dense tensors with non-trivial affine expression that just
-  // becomes invariant and the address shall now be generated at the current
-  // level.
+  // becomes invariant and the address are generated at the current level.
   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
   bool isSingleCond =
       translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
@@ -1161,59 +1161,56 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,

 /// Ends a single loop in current sequence. Returns new values for needsUniv.
 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
-                    LoopId idx, LatPointId li, bool needsUniv,
-                    bool isSingleCond) {
-
+                    LatPointId li, bool needsUniv, bool isSingleCond) {
+  // Either a for-loop or a while-loop that iterates over a slice.
   if (isSingleCond) {
-    // Either a for-loop or a while-loop that iterates over a slice.
     // Any iteration creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv);
+    finalizeWhileOp(env, rewriter, needsUniv);
   } else {
     needsUniv = false;
   }
-
   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
-
   return needsUniv;
 }

 /// Ends a loop sequence at given level.
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
-                       unsigned idx, unsigned ldx) {
-  assert(!env.getLoopVar(idx));
+                       unsigned at) {
+  assert(!env.getLoopVar(at));
   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, at, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }

 /// Recursively generates code while computing iteration lattices in order
 /// to manage the complexity of implementing co-iteration over unions
 /// and intersections of sparse iterations spaces.
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopId at) {
+  assert(at == env.getLoopDepth());
+
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
     Value rhs = genExp(env, rewriter, exp);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }

-  // Construct iteration lattices for current loop index, with L0 at top.
-  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
+  // Construct iteration lattices for current loop index.
   const LatSetId lts =
       env.merger().optimizeSet(env.merger().buildLattices(exp, at));

   // Start a loop sequence.
-  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, lts);

   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
   // We cannot change this to `for (const LatPointId li : env.set(lts))`
@@ -1250,11 +1247,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
    }

    // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
+    needsUniv = endLoop(env, rewriter, loop, at, needsUniv, isSingleCond);
   }

   // End a loop sequence.
-  endLoopSeq(env, rewriter, exp, at, ldx);
+  endLoopSeq(env, rewriter, exp, at);
+  assert(at == env.getLoopDepth());
 }

 /// Converts the result computed by the sparse kernel into the required form.
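With `ldx` gone, genStmt threads a single `at` cursor that is both the loop id and the current depth, as the new entry/exit assertions document. A toy sketch of that recursion shape (hypothetical `genStmtSketch`, not the actual driver): each level starts one loop sequence, recurses one level deeper, and unwinds in reverse order.

#include <cstdio>

// Each recursion level owns one loop sequence; `at` equals the current
// loop depth on entry and on exit, so no separate ldx (= at - 1) is needed.
static void genStmtSketch(unsigned at, unsigned loopNum) {
  if (at == loopNum) { // leaf: assign expression to the output tensor
    std::printf("%*sstore\n", (int)(2 * at), "");
    return;
  }
  std::printf("%*sstart loop seq %u\n", (int)(2 * at), "", at);
  genStmtSketch(at + 1, loopNum); // generate the loop body one level deeper
  std::printf("%*send loop seq %u\n", (int)(2 * at), "", at);
}

int main() {
  genStmtSketch(0, /*loopNum=*/3);
  return 0;
}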
@@ -1309,6 +1307,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
               "before sparsification.");
     }
+
     // Must have been demapped as well if the generic op is sorted.
     assert(!hasAnyNonIdentityOperandsOrResults(op));
