
Commit d933b88

[mlir][sparse] use a common util function to query the tensor level set in a lattice point. (#76764)
1 parent 1a8fb88 commit d933b88

File tree

1 file changed: +86 -94 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 86 additions & 94 deletions
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 // Sparsifier synthesis methods (loop sequence).
 //===----------------------------------------------------------------------===//

-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId curr, LatSetId lts) {
-  assert(!env.getLoopVar(curr));
-  // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, curr, /*isStart=*/true);
-  // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, curr, /*isStart=*/true);
-  // Emit further intitialization at this loop sequence level.
-  const LatPointId l0 = env.set(lts)[0];
-  bool needsUniv = false;
-
-  SmallVector<TensorLevel> tidLvls;
-  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
-                                           std::optional<Level> lvl,
-                                           LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == curr);
-    if (isDenseLT(lt) || isUndefLT(lt)) {
-      if (tid == env.merger().getSynTensorID()) {
-        // Needs loop emitter to set up loop bounds for synthetic tensor too if
-        // there is a loop condition imposed on the synthetic tensor.
-        tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
-      }
-      needsUniv = true;
-    }
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt) || isIdxReduc) {
-      // Only when this is a index reduction loop, can the lt be undefined.
-      assert(!isUndefLT(lt) || isIdxReduc);
-      // sparse/singleton levels, or a dense/sparse index reduction loop.
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
-    }
-  });
-
-  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
-  // Maintain the universal index only if it is actually
-  // consumed by a subsequent lattice point.
-  if (needsUniv) {
-    for (const LatPointId li : env.set(lts).drop_front())
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
-        return true;
-  }
-  return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
-                                             OpBuilder &builder, TensorId tid,
-                                             Level startLvl) {
-  // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.op();
-  assert(tid < op.getNumDpsInputs());
-  OpOperand *input = op.getDpsInputOperands()[tid];
-  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
-  const auto enc = getSparseTensorEncoding(input->get().getType());
-  if (enc) {
-    const Location loc = op.getLoc();
-    const TensorId tid = env.makeTensorId(input->getOperandNumber());
-    const Level lvlRank = enc.getLvlRank();
-    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    for (Level l = startLvl; l < lvlRank; l++) {
-      AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-        env.emitter().genDenseAffineAddress(
-            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
-      else
-        return; // break on first non-dense non-constant level
-    }
-  }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
-                                        RewriterBase &rewriter) {
-  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
     CodegenEnv &env, LatPointId li, LoopId curr,
-    SmallVectorImpl<TensorLevel> &tidLvls,
-    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
                                            LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
-        tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+        callback(env.makeTensorLevel(tid, *lvl), nullptr);
         numloopCond++;
         return;
       }
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
         }
       }
       hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
       numloopCond++;
     } else if (isDenseLT(lt) || isIdxReduc) {
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
     } else {
       assert(isUndefLT(lt));
       linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
           // level. We need to generate the address according to the
           // affine expression. This is also the best place we can do it
           // to avoid putting it inside inner loops.
-          affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+          callback(env.makeTensorLevel(tid, l), exp);
         }
       }
     }
@@ -1120,22 +1035,99 @@ static bool translateBitsToTidLvlPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }

   if (numloopCond == 0) {
     // Corner cases where the loop bound is defined by a *unused* operand, in
     // this case, we just generate a dense "fake" loop by iterating over the
     // synthetic tensor.
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
 }

+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr));
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization at this loop sequence level.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber());
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // break on first non-dense non-constant level
+    }
+  }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+    CodegenEnv &env, LatPointId li, LoopId curr,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+  return getAllTidLvlsInLatPoints(env, li, curr,
+                                  [&](TensorLevel tl, AffineExpr exp) {
+                                    if (exp)
+                                      affineTidLvls.emplace_back(tl, exp);
+                                    else
+                                      tidLvls.emplace_back(tl);
+                                  });
+}
+
 /// Starts a single loop in current sequence.
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,
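The change above folds the two near-identical lattice-point traversals into one utility, getAllTidLvlsInLatPoints, which reports every tensor level through a callback; startLoopSeq and translateBitsToTidLvlPairs become thin wrappers that differ only in how they store the reported pairs. Below is a minimal standalone sketch of that callback pattern, not the actual MLIR code: TensorLevel and AffineExpr are simplified stand-ins (an integer id and a nullable string), the reported values are fabricated for illustration, and std::function is used in place of llvm::function_ref so the sketch compiles without LLVM headers.

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Simplified stand-ins for the real MLIR types (illustration only).
using TensorLevel = unsigned;    // compressed (tensor id, level) pair in the real code
using AffineExpr = const char *; // nullptr means "plain loop condition, no affine address"

// The common utility: walks the levels of one lattice point and reports each
// one through a callback instead of filling caller-specific containers.
static bool getAllTidLvlsInLatPoints(
    const std::function<void(TensorLevel, AffineExpr)> &callback) {
  callback(0, nullptr);  // a level that contributes a loop condition
  callback(1, "d0 + 2"); // a dense level addressed by a constant/affine expression
  return true;           // e.g. "a simple for loop suffices"
}

// Caller 1 (mirrors startLoopSeq): only collects the tensor levels.
static std::vector<TensorLevel> collectTidLvls() {
  std::vector<TensorLevel> tidLvls;
  getAllTidLvlsInLatPoints(
      [&](TensorLevel tl, AffineExpr) { tidLvls.push_back(tl); });
  return tidLvls;
}

// Caller 2 (mirrors translateBitsToTidLvlPairs): splits the results into two
// lists depending on whether an affine expression was attached.
static bool splitTidLvls(
    std::vector<TensorLevel> &tidLvls,
    std::vector<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
  return getAllTidLvlsInLatPoints([&](TensorLevel tl, AffineExpr exp) {
    if (exp)
      affineTidLvls.emplace_back(tl, exp);
    else
      tidLvls.push_back(tl);
  });
}

int main() {
  std::vector<TensorLevel> tidLvls;
  std::vector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
  bool isForLoop = splitTidLvls(tidLvls, affineTidLvls);
  std::cout << "plain levels: " << tidLvls.size()
            << ", affine levels: " << affineTidLvls.size()
            << ", for-loop: " << std::boolalpha << isForLoop << "\n";
  std::cout << "collected by startLoopSeq-style caller: "
            << collectTidLvls().size() << "\n";
}

Handing the traversal a callback keeps the shared walk agnostic of how each caller stores its results, which is what lets the patch delete the duplicated loop over the lattice point's bits.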

0 commit comments
