[mlir][sparse] use a common util function to query the tensor level set in a lattice point #76764
Merged
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes: use a common util function to query the tensor level set in a lattice point.

Full diff: https://github.com/llvm/llvm-project/pull/76764.diff (1 file affected).
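At its core, the change replaces two out-parameter vectors with a single shared traversal that reports each (tensor, level) pair through a caller-supplied callback; each call site then decides what to collect. The sketch below illustrates that pattern only — the names getAllTidLvls, translate, and collectLevels, the TensorLevel struct, and the string stand-in for AffineExpr are simplified stand-ins, not the actual sparsifier code (which uses llvm::function_ref and far richer level-type logic):

// Minimal, self-contained sketch of the callback consolidation
// (hypothetical names/types; NOT the real Sparsification.cpp code).
#include <functional>
#include <utility>
#include <vector>

struct TensorLevel { unsigned tid, lvl; };
using AffineExpr = const char *; // stand-in for mlir::AffineExpr

// Analogue of getAllTidLvlsInLatPoints: walk a lattice point once,
// report each entry, and return whether a single for loop suffices.
static bool getAllTidLvls(
    const std::vector<std::pair<TensorLevel, AffineExpr>> &latPoint,
    const std::function<void(TensorLevel, AffineExpr)> &callback) {
  unsigned numLoopCond = 0;
  for (const auto &[tl, expr] : latPoint) {
    callback(tl, expr);
    if (!expr) // entries without an affine expression impose loop conditions
      ++numLoopCond;
  }
  return numLoopCond == 1; // simplified; the real check is more involved
}

// Analogue of translateBitsToTidLvlPairs: split results into two vectors.
static bool translate(
    const std::vector<std::pair<TensorLevel, AffineExpr>> &latPoint,
    std::vector<TensorLevel> &tidLvls,
    std::vector<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvls(latPoint, [&](TensorLevel tl, AffineExpr e) {
    if (e)
      affineTidLvls.emplace_back(tl, e);
    else
      tidLvls.push_back(tl);
  });
}

// Analogue of the startLoopSeq caller: only the levels themselves matter.
static std::vector<TensorLevel> collectLevels(
    const std::vector<std::pair<TensorLevel, AffineExpr>> &latPoint) {
  std::vector<TensorLevel> tidLvls;
  getAllTidLvls(latPoint,
                [&](TensorLevel tl, AffineExpr) { tidLvls.push_back(tl); });
  return tidLvls;
}

The full diff of the change: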
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e559f44d6..7be2f30d26d8ba 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopId curr, LatSetId lts) {
- assert(!env.getLoopVar(curr));
- // Emit invariants at this loop sequence level.
- genInvariants(env, builder, exp, curr, /*isStart=*/true);
- // Emit access pattern expansion for sparse tensor output.
- genExpand(env, builder, curr, /*isStart=*/true);
- // Emit further intitialization at this loop sequence level.
- const LatPointId l0 = env.set(lts)[0];
- bool needsUniv = false;
-
- SmallVector<TensorLevel> tidLvls;
- env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl,
- LevelType lt, bool isIdxReduc) {
- assert(env.merger().loop(b) == curr);
- if (isDenseLT(lt) || isUndefLT(lt)) {
- if (tid == env.merger().getSynTensorID()) {
- // Needs loop emitter to set up loop bounds for synthetic tensor too if
- // there is a loop condition imposed on the synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
- }
- needsUniv = true;
- }
- if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt) || isIdxReduc) {
- // Only when this is a index reduction loop, can the lt be undefined.
- assert(!isUndefLT(lt) || isIdxReduc);
- // sparse/singleton levels, or a dense/sparse index reduction loop.
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
- }
- });
-
- env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
- // Maintain the universal index only if it is actually
- // consumed by a subsequent lattice point.
- if (needsUniv) {
- for (const LatPointId li : env.set(lts).drop_front())
- if (!env.merger().hasAnySparse(env.lat(li).simple))
- return true;
- }
- return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
- OpBuilder &builder, TensorId tid,
- Level startLvl) {
- // TODO: Handle affine expression on output tensor.
- linalg::GenericOp op = env.op();
- assert(tid < op.getNumDpsInputs());
- OpOperand *input = op.getDpsInputOperands()[tid];
- const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
- const auto enc = getSparseTensorEncoding(input->get().getType());
- if (enc) {
- const Location loc = op.getLoc();
- const TensorId tid = env.makeTensorId(input->getOperandNumber());
- const Level lvlRank = enc.getLvlRank();
- assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
- for (Level l = startLvl; l < lvlRank; l++) {
- AffineExpr lvlExpr = lvlExprs[l];
- if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
- env.emitter().genDenseAffineAddress(
- builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
- else
- return; // break on first non-dense non-constant level
- }
- }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
- RewriterBase &rewriter) {
- for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
- genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
CodegenEnv &env, LatPointId li, LoopId curr,
- SmallVectorImpl<TensorLevel> &tidLvls,
- SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
LevelType lt, bool isIdxReduc) {
if (simple[b]) {
if (isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
return;
}
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
} else if (isDenseLT(lt) || isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
} else {
assert(isUndefLT(lt));
linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
- affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+ callback(env.makeTensorLevel(tid, l), exp);
}
}
}
@@ -1120,15 +1035,14 @@ static bool translateBitsToTidLvlPairs(
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
- tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+ callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
if (numloopCond == 0) {
// Corner cases where the loop bound is defined by a *unused* operand, in
// this case, we just generate a dense "fake" loop by iterating over the
// synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
- env.getCurrentDepth()));
+ callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
numloopCond++;
}
// If we just need to one loop conditions and the conditions is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
return numloopCond == 1 && !hasNonUnique;
}
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopId curr, LatSetId lts) {
+ assert(!env.getLoopVar(curr));
+ // Emit invariants at this loop sequence level.
+ genInvariants(env, builder, exp, curr, /*isStart=*/true);
+ // Emit access pattern expansion for sparse tensor output.
+ genExpand(env, builder, curr, /*isStart=*/true);
+ // Emit further initialization at this loop sequence level.
+ const LatPointId l0 = env.set(lts)[0];
+
+ SmallVector<TensorLevel> tidLvls;
+ getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ tidLvls.emplace_back(tl);
+ });
+
+ env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+ // Maintain the universal index only if it is actually
+ // consumed by a subsequent lattice point.
+ for (const LatPointId li : env.set(lts).drop_front())
+ if (!env.merger().hasAnySparse(env.lat(li).simple))
+ return true;
+
+ return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+ OpBuilder &builder, TensorId tid,
+ Level startLvl) {
+ // TODO: Handle affine expression on output tensor.
+ linalg::GenericOp op = env.op();
+ assert(tid < op.getNumDpsInputs());
+ OpOperand *input = op.getDpsInputOperands()[tid];
+ const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+ const auto enc = getSparseTensorEncoding(input->get().getType());
+ if (enc) {
+ const Location loc = op.getLoc();
+ const TensorId tid = env.makeTensorId(input->getOperandNumber());
+ const Level lvlRank = enc.getLvlRank();
+ assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+ for (Level l = startLvl; l < lvlRank; l++) {
+ AffineExpr lvlExpr = lvlExprs[l];
+ if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ env.emitter().genDenseAffineAddress(
+ builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+ else
+ return; // break on first non-dense non-constant level
+ }
+ }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on any thing.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+ RewriterBase &rewriter) {
+ for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+ genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Return true if the lattices bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+ CodegenEnv &env, LatPointId li, LoopId curr,
+ SmallVectorImpl<TensorLevel> &tidLvls,
+ SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ return getAllTidLvlsInLatPoints(env, li, curr,
+ [&](TensorLevel tl, AffineExpr exp) {
+ if (exp)
+ affineTidLvls.emplace_back(tl, exp);
+ else
+ tidLvls.emplace_back(tl);
+ });
+}
+
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
OpBuilder &builder, LoopId curr,
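A note on the design choice: the new getAllTidLvlsInLatPoints takes an llvm::function_ref rather than returning a container. function_ref is the idiomatic LLVM type for non-escaping callbacks — a non-owning, trivially copyable view of a callable — so the shared traversal adds no allocation and each caller keeps its own storage strategy. (This rationale is an editorial inference; the PR itself does not state it.) A minimal usage sketch:

#include "llvm/ADT/STLFunctionalExtras.h"

// function_ref does not own its callable, so it must not outlive the
// lambda it views -- fine here, since apply() only calls it synchronously.
static int apply(llvm::function_ref<int(int)> f, int x) { return f(x); }

// e.g. apply([](int v) { return v * v; }, 3) == 9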
yinying-lisa-li approved these changes on Jan 2, 2024.
aartbik approved these changes on Jan 2, 2024.