[mlir][sparse] use a common util function to query the tensor level s… #76764

Merged (1 commit) on Jan 2, 2024.

180 changes: 86 additions & 94 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 // Sparsifier synthesis methods (loop sequence).
 //===----------------------------------------------------------------------===//

-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId curr, LatSetId lts) {
-  assert(!env.getLoopVar(curr));
-  // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, curr, /*isStart=*/true);
-  // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, curr, /*isStart=*/true);
-  // Emit further intitialization at this loop sequence level.
-  const LatPointId l0 = env.set(lts)[0];
-  bool needsUniv = false;
-
-  SmallVector<TensorLevel> tidLvls;
-  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
-                                           std::optional<Level> lvl,
-                                           LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == curr);
-    if (isDenseLT(lt) || isUndefLT(lt)) {
-      if (tid == env.merger().getSynTensorID()) {
-        // Needs loop emitter to set up loop bounds for synthetic tensor too if
-        // there is a loop condition imposed on the synthetic tensor.
-        tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
-      }
-      needsUniv = true;
-    }
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt) || isIdxReduc) {
-      // Only when this is a index reduction loop, can the lt be undefined.
-      assert(!isUndefLT(lt) || isIdxReduc);
-      // sparse/singleton levels, or a dense/sparse index reduction loop.
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
-    }
-  });
-
-  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
-  // Maintain the universal index only if it is actually
-  // consumed by a subsequent lattice point.
-  if (needsUniv) {
-    for (const LatPointId li : env.set(lts).drop_front())
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
-        return true;
-  }
-  return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
-                                             OpBuilder &builder, TensorId tid,
-                                             Level startLvl) {
-  // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.op();
-  assert(tid < op.getNumDpsInputs());
-  OpOperand *input = op.getDpsInputOperands()[tid];
-  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
-  const auto enc = getSparseTensorEncoding(input->get().getType());
-  if (enc) {
-    const Location loc = op.getLoc();
-    const TensorId tid = env.makeTensorId(input->getOperandNumber());
-    const Level lvlRank = enc.getLvlRank();
-    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    for (Level l = startLvl; l < lvlRank; l++) {
-      AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-        env.emitter().genDenseAffineAddress(
-            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
-      else
-        return; // break on first non-dense non-constant level
-    }
-  }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
-                                        RewriterBase &rewriter) {
-  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
     CodegenEnv &env, LatPointId li, LoopId curr,
-    SmallVectorImpl<TensorLevel> &tidLvls,
-    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
                                                  LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
-        tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+        callback(env.makeTensorLevel(tid, *lvl), nullptr);
         numloopCond++;
         return;
       }
@@ -1072,10 +987,10 @@
         }
       }
       hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
       numloopCond++;
     } else if (isDenseLT(lt) || isIdxReduc) {
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
     } else {
       assert(isUndefLT(lt));
       linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@
             // level. We need to generate the address according to the
             // affine expression. This is also the best place we can do it
             // to avoid putting it inside inner loops.
-            affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+            callback(env.makeTensorLevel(tid, l), exp);
           }
         }
       }
@@ -1120,22 +1035,99 @@
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }

   if (numloopCond == 0) {
     // Corner cases where the loop bound is defined by an *unused* operand; in
     // this case, we just generate a dense "fake" loop by iterating over the
     // synthetic tensor.
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we just need one loop condition and the condition is not imposed on a
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
 }

+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr));
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization at this loop sequence level.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber());
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // break on first non-dense non-constant level
+    }
+  }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+    CodegenEnv &env, LatPointId li, LoopId curr,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+  return getAllTidLvlsInLatPoints(env, li, curr,
+                                  [&](TensorLevel tl, AffineExpr exp) {
+                                    if (exp)
+                                      affineTidLvls.emplace_back(tl, exp);
+                                    else
+                                      tidLvls.emplace_back(tl);
+                                  });
+}
+
 /// Starts a single loop in current sequence.
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,
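
The refactoring pattern in this diff is a callback extraction: the level-set traversal that used to fill two output vectors inside translateBitsToTidLvlPairs now lives in getAllTidLvlsInLatPoints, which reports each (TensorLevel, AffineExpr) pair through an llvm::function_ref callback, so startLoopSeq and the translateBitsToTidLvlPairs wrapper can each store the results however they need. The standalone sketch below illustrates just that pattern, not the PR's actual code: the names visitTidLvls, collectPairs, and collectLevels are hypothetical, the item list is fabricated for the demo, and std::function stands in for llvm::function_ref so the snippet compiles without LLVM headers.

// Standalone sketch of the callback-extraction refactoring used in this PR.
// All names here are hypothetical; std::function stands in for
// llvm::function_ref.
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

using TensorLevel = unsigned;
using AffineExpr = const char *; // stand-in; nullptr means "no affine expr"

// The shared traversal: instead of filling fixed output vectors, it reports
// every (level, expr) pair to a caller-supplied callback.
static void visitTidLvls(
    const std::vector<std::pair<TensorLevel, AffineExpr>> &items,
    const std::function<void(TensorLevel, AffineExpr)> &callback) {
  for (const auto &[tl, exp] : items)
    callback(tl, exp);
}

// One caller reproduces the old two-vector interface...
static void collectPairs(
    const std::vector<std::pair<TensorLevel, AffineExpr>> &items,
    std::vector<TensorLevel> &tidLvls,
    std::vector<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  visitTidLvls(items, [&](TensorLevel tl, AffineExpr exp) {
    if (exp)
      affineTidLvls.emplace_back(tl, exp);
    else
      tidLvls.emplace_back(tl);
  });
}

// ...while another caller (like startLoopSeq above) only needs the levels.
static void collectLevels(
    const std::vector<std::pair<TensorLevel, AffineExpr>> &items,
    std::vector<TensorLevel> &tidLvls) {
  visitTidLvls(items,
               [&](TensorLevel tl, AffineExpr) { tidLvls.push_back(tl); });
}

int main() {
  const std::vector<std::pair<TensorLevel, AffineExpr>> items = {
      {0, nullptr}, {1, "2 * d0"}, {2, nullptr}};
  std::vector<TensorLevel> lvls;
  std::vector<std::pair<TensorLevel, AffineExpr>> affine;
  collectPairs(items, lvls, affine);
  std::printf("%zu plain levels, %zu affine levels\n", lvls.size(),
              affine.size()); // prints: 2 plain levels, 1 affine levels
  return 0;
}

In the real patch, llvm::function_ref is the better fit than std::function: it is a non-owning view of a callable, so it avoids any allocation and signals that the callback is never stored beyond the call.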