
Commit ccd923e

[mlir][sparse] code cleanup (remove dead code related to filter loop). (#72573)
1 parent 4639610 commit ccd923e

6 files changed: +38 additions, -139 deletions

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 7 additions & 39 deletions

@@ -223,29 +223,16 @@ struct LatPoint final {
 /// independently from the basic algorithm if bottlenecks are identified.
 class Merger {
 public:
-  /// Constructs a merger for the given number of tensors, native loops, and
-  /// filter loops. The user supplies the number of tensors involved in the
-  /// kernel, with the last tensor in this set denoting the output tensor.
-  /// The merger adds an additional synthetic tensor at the end of this set
-  /// to represent all invariant expressions in the kernel.
-  ///
-  /// In addition to natives loops (which are specified by the GenericOp),
-  /// extra filter loops are needed in order to handle affine expressions on
-  /// sparse levels. E.g., (d0, d1, d2) => (d0 + d1, d2), a naive
-  /// implementation of the filter loop could be generated as
-  ///
-  /// for (const auto c0 : coordinates[0]) {
-  ///   if (c0 == d0 + d1) {
-  ///     generated_code;
-  ///   }
-  /// }
-  ///
-  /// to filter out coordinates that are not equal to the affine expression.
+  /// Constructs a merger for the given number of tensors and loops. The user
+  /// supplies the number of tensors involved in the kernel, with the last
+  /// tensor in this set denoting the output tensor. The merger adds an
+  /// additional synthetic tensor at the end of this set to represent all
+  /// invariant expressions in the kernel.
   ///
   /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
   /// It is used to pre-allocate sufficient memory for internal storage.
-  Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
-         unsigned numFilterLoops, unsigned maxLvlRank);
+  Merger(unsigned numInputOutputTensors, unsigned numLoops,
+         unsigned maxLvlRank);
 
   //
   // Constructing valid tensor and loop identifiers.
@@ -366,19 +353,6 @@ class Merger {
   /// Gets the total number of loops (native loops + filter loops).
   constexpr unsigned getNumLoops() const { return numLoops; }
 
-  /// Gets the number of native loops.
-  constexpr unsigned getNumNativeLoops() const { return numNativeLoops; }
-
-  /// Gets the number of filter loops.
-  constexpr unsigned getNumFilterLoops() const {
-    return numLoops - numNativeLoops;
-  }
-
-  /// Gets the identifier of the first filter-loop.
-  constexpr LoopId getStartingFilterLoopId() const {
-    return getNumNativeLoops();
-  }
-
   /// Returns true if `b` is the `i`th loop of the output tensor.
   constexpr bool isOutTensor(TensorLoopId b, LoopId i) const {
     return b == makeTensorLoopId(outTensor, i);
@@ -391,11 +365,6 @@ class Merger {
   /// tensor expressions).
   constexpr TensorId getSynTensorID() const { return syntheticTensor; }
 
-  constexpr bool isFilterLoop(LoopId i) const {
-    assert(isValidLoopId(i));
-    return i >= numNativeLoops;
-  }
-
   /// Returns true if the expression is `(kTensor t)`.
   bool expIsTensor(ExprId e, TensorId t) const {
     const auto &expr = exp(e);
@@ -657,7 +626,6 @@ class Merger {
   const TensorId outTensor;
   const TensorId syntheticTensor;
   const unsigned numTensors;
-  const unsigned numNativeLoops;
  const unsigned numLoops;
  bool hasSparseOut;
 

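For readers unfamiliar with the retired strategy, the doc comment deleted above described filter loops with a small loop/if pattern. The following standalone C++ sketch of that pattern (function and variable names are invented for illustration; this is not the MLIR implementation) makes the cost visible: every stored coordinate of the sparse level is scanned just to keep those matching the affine expression d0 + d1.

#include <cstdint>
#include <vector>

// Sketch of the filter-loop strategy this patch removes: given the stored
// coordinates of a sparse level and the current loop values d0 and d1, only
// coordinates equal to d0 + d1 reach the generated kernel body.
void filterLoopSketch(const std::vector<int64_t> &coordinates, int64_t d0,
                      int64_t d1) {
  for (const int64_t c0 : coordinates) { // scans every stored coordinate
    if (c0 == d0 + d1) {
      // generated_code: kernel body for this matching coordinate.
    }
  }
}

After this cleanup, affine expressions on sparse levels go through the dependent-index-reduction path instead (see findSparseAnnotations in Sparsification.cpp below).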
mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp

Lines changed: 5 additions & 7 deletions

@@ -42,14 +42,12 @@ static void sortDependentLoops(std::vector<LoopCoeffPair> &target) {
 //===----------------------------------------------------------------------===//
 
 CodegenEnv::CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
-                       unsigned numTensors, unsigned numLoops,
-                       unsigned numFilterLoops, unsigned maxRank)
+                       unsigned numTensors, unsigned numLoops, unsigned maxRank)
     : linalgOp(linop), sparseOptions(opts),
-      latticeMerger(numTensors, numLoops, numFilterLoops, maxRank),
-      loopEmitter(), sparseOut(nullptr), outerParNest(-1u), insChain(),
-      expValues(), expFilled(), expAdded(), expCount(), redVal(),
-      redExp(detail::kInvalidId), redCustom(detail::kInvalidId),
-      redValidLexInsert() {}
+      latticeMerger(numTensors, numLoops, maxRank), loopEmitter(),
+      sparseOut(nullptr), outerParNest(-1u), insChain(), expValues(),
+      expFilled(), expAdded(), expCount(), redVal(), redExp(detail::kInvalidId),
+      redCustom(detail::kInvalidId), redValidLexInsert() {}
 
 LogicalResult CodegenEnv::initTensorExp() {
   // Builds the tensor expression for the Linalg operation in SSA form.

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h

Lines changed: 1 addition & 2 deletions

@@ -38,8 +38,7 @@ class CodegenEnv {
   /// passed around during sparsification for bookkeeping
   /// together with some consistency asserts.
   CodegenEnv(linalg::GenericOp linop, SparsificationOptions opts,
-             unsigned numTensors, unsigned numLoops, unsigned numFilterLoops,
-             unsigned maxRank);
+             unsigned numTensors, unsigned numLoops, unsigned maxRank);
 
   //
   // General methods.

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 20 additions & 85 deletions

@@ -78,12 +78,8 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
 /// Helper method to inspect affine expressions. Rejects cases where the
 /// same index is used more than once. Also rejects compound affine
 /// expressions in sparse dimensions.
-/// filterIdx stores the current filter loop idx should be used for the next
-/// compound affine sparse level, and it will be incremented by one when
-/// used.
 static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
-                       DimLevelType dlt, LoopId &filterLdx,
-                       bool setLvlFormat = true) {
+                       DimLevelType dlt, bool setLvlFormat = true) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
@@ -97,22 +93,14 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
   case AffineExprKind::Add:
   case AffineExprKind::Mul:
   case AffineExprKind::Constant: {
-    if (!isDenseDLT(dlt) && setLvlFormat) {
-      assert(isUndefDLT(merger.getLvlType(tid, filterLdx)));
-      // Use a filter loop for sparse affine expression.
-      merger.setLevelAndType(tid, filterLdx, lvl, dlt);
-      ++filterLdx;
-    }
-
+    assert(isDenseDLT(dlt));
     if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
       // We do not set dim level format for affine expression like d0 + d1 on
       // either loop index at d0 or d1.
       // We continue the recursion merely to check whether current affine is
       // admissible or not.
-      return findAffine(merger, tid, lvl, binOp.getLHS(), dlt, filterLdx,
-                        false) &&
-             findAffine(merger, tid, lvl, binOp.getRHS(), dlt, filterLdx,
-                        false);
+      return findAffine(merger, tid, lvl, binOp.getLHS(), dlt, false) &&
+             findAffine(merger, tid, lvl, binOp.getRHS(), dlt, false);
     }
     // Falls through when it is a constant Affine
     return true;
@@ -225,32 +213,13 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
     return 0;
   const SparseTensorType stt(rtp);
 
-  // FIXME: There's some dim/lvl confusion here. The previous version of
-  // the code asserted that there are `lvlRank`-many expressions, but then
-  // the `exprs[d]` expression assumes there are in fact `dimRank`-many
-  // expressions. Even though `ArrayRef::operator[]` will check for OOB,
-  // the mismatch between the assertion and the usage belies that this code
-  // cannot support non-permutations.
-  //
-  // Elsewhere in this file the maps returned by
-  // `linalg::GenericOp::getMatchingIndexingMap` are inconsistent about
-  // whether they're expected to have `lvlRank`-many or `dimRank`-many
-  // expressions (cf., `genSubscript` vs `findSparseAnnotations`);
-  // so those are no help in determining which is actually intended.
-  //
-  // For now we work around this problem by asserting the two ranks agree.
-  const Dimension dimRank = stt.getDimRank();
   const Level lvlRank = stt.getLvlRank();
-  assert(dimRank == lvlRank && "Non-permutations not currently supported");
   const auto exprs = map.getResults();
-  assert(static_cast<Dimension>(exprs.size()) == dimRank &&
+  assert(static_cast<Dimension>(exprs.size()) == lvlRank &&
          "AffineMap does not have dimension-rank many results");
-  (void)dimRank;
   unsigned num = 0;
   for (Level l = 0; l < lvlRank; l++) {
-    // FIXME: `toOrigDim` is deprecated.
-    const Dimension d = toOrigDim(stt.getEncoding(), l);
-    if (!isa<AffineDimExpr>(exprs[d]) && !stt.isDenseLvl(l))
+    if (!isa<AffineDimExpr>(exprs[l]) && !stt.isDenseLvl(l))
       num++;
   }
   return num;
@@ -281,15 +250,10 @@ static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
 /// no annotations are found or inadmissible constructs occur.
 /// We currently support two different ways to handle non-trivial index
 /// expression on sparse tensors, and they accept different affine expressions.
-/// When using filter-loop-based approach, it accept (almost) arbitrary affine
-/// index expression on sparse tensor but it is much less efficient, and will be
-/// gradually removed from the codebase.
 /// When using dependent index reducton-based approach, it currently only
 /// supports affine addition index expression.
 static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
   bool annotated = false;
-  // `filterLdx` may be mutated by `findAffine`.
-  LoopId filterLdx = env.merger().getStartingFilterLoopId();
   for (OpOperand &t : env.op()->getOpOperands()) {
     const TensorId tid = env.makeTensorId(t.getOperandNumber());
     const auto map = env.op().getMatchingIndexingMap(&t);
@@ -310,19 +274,17 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
     // If then current tensor being inspected requires affine index, it need
     // to be sliced.
     for (Level l = 0; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      const AffineExpr a = map.getResult(toOrigDim(enc, l));
+      const AffineExpr a = map.getResult(l);
       const DimLevelType dlt = enc.getLvlType(l);
      if (idxReducBased && needIdxReduc) {
        if (!findDepIdxSet(env.merger(), tid, l, a, dlt))
          return false; // inadmissible affine expression
      } else {
-        if (!findAffine(env.merger(), tid, l, a, dlt, filterLdx))
+        if (!findAffine(env.merger(), tid, l, a, dlt))
          return false; // inadmissible affine expression
      }
    }
  }
-  assert(filterLdx == env.merger().getNumLoops());
  return annotated;
 }
 
@@ -374,13 +336,8 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
        }
        return init;
      },
-      [&loopRange, &env](OpBuilder &b, Location loc, Level l) {
-        assert(l < env.getLoopNum());
-        // FIXME: Remove filter loop since we have a better algorithm to
-        // deal with affine index expression.
-        if (l >= env.merger().getStartingFilterLoopId())
-          return Value();
-
+      [&loopRange](OpBuilder &b, Location loc, Level l) {
+        assert(l < loopRange.size());
        return mlir::getValueOrCreateConstantIndexOp(b, loc, loopRange[l].size);
      });
 }
@@ -394,10 +351,7 @@ static Value genIndex(CodegenEnv &env, OpOperand *t) {
   const auto stt = getSparseTensorType(t->get());
   const Level lvlRank = stt.getLvlRank();
   assert(static_cast<Level>(map.getNumResults()) == lvlRank);
-  // FIXME: `toOrigDim` is deprecated.
-  // FIXME: above we asserted that there are `lvlRank` many results,
-  // but this is assuming there are in fact `dimRank` many results instead.
-  const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), lvlRank - 1));
+  const AffineExpr a = map.getResult(lvlRank - 1);
   assert(a.getKind() == AffineExprKind::DimId);
   const LoopId idx = env.makeLoopId(cast<AffineDimExpr>(a).getPosition());
   return env.getLoopVar(idx);
@@ -727,19 +681,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
    const Level lvlRank = stt.getLvlRank();
    assert(static_cast<Level>(map.getNumResults()) == lvlRank);
    for (Level l = 0; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      // FIXME: above we asserted that there are `lvlRank` many results,
-      // but this is assuming there are in fact `dimRank` many results instead.
-      const AffineExpr a = map.getResult(toOrigDim(stt.getEncoding(), l));
-      const auto sldx =
-          env.merger().getLoopId(env.makeTensorId(t.getOperandNumber()), l);
-      if (sldx && env.merger().isFilterLoop(*sldx)) {
-        if (!env.getLoopVar(*sldx))
-          // The filter loops has not been constructed.
-          return;
-        if (*sldx == ldx)
-          isAtLoop = true;
-      } else if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+      const AffineExpr a = map.getResult(l);
+      if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
        return; // still in play
    }
    // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -1073,10 +1016,8 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
    const TensorId tid = env.makeTensorId(input->getOperandNumber());
    const Level lvlRank = enc.getLvlRank();
    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    // FIXME: there is dim/lvl confusion here
    for (Level l = startLvl; l < lvlRank; l++) {
-      // FIXME: `toOrigDim` is deprecated.
-      AffineExpr lvlExpr = lvlExprs[toOrigDim(enc, l)];
+      AffineExpr lvlExpr = lvlExprs[l];
      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
        env.emitter().genDenseAffineAddress(
            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
@@ -1164,8 +1105,7 @@ static bool translateBitsToTidLvlPairs(
      const Level lvlRank = stt.getLvlRank();
      assert(affines.size() == static_cast<size_t>(lvlRank));
      for (Level l = 0; l < lvlRank; l++) {
-        // FIXME: `toOrigDim` is deprecated.
-        AffineExpr exp = affines[toOrigDim(stt.getEncoding(), l)];
+        AffineExpr exp = affines[l];
        // Skip simple affine expression and non-dense levels (which
        // have their own filter loop).
        if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
@@ -1396,14 +1336,13 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
          op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
              "before sparsification.");
    }
+    // Must have been demapped as well if the generic op is sorted.
+    assert(!hasAnyNonIdentityOperandsOrResults(op));
 
    // Sets up a code generation environment.
    const unsigned numTensors = op->getNumOperands();
    const unsigned numLoops = op.getNumLoops();
-    const unsigned numFilterLoops = getNumNonTrivialIdxExpOnSparseLvls(op);
-    // TODO: we should probably always use slice-based codegen whenever
-    // possible, we can even intermix slice-based and filter-loop based codegen.
-    bool idxReducBased = numFilterLoops != 0;
+    bool needIdxRed = getNumNonTrivialIdxExpOnSparseLvls(op) != 0;
    // If we have indexing map like (d0) -> (0, d0), there might be more
    // levels then loops because of the constant index, that means we can not
    // use numLoops as the upper bound for ranks of all tensors.
@@ -1417,14 +1356,10 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
      }
    }
 
-    // A slice based algorithm for affine indices does not need filter loops.
-    CodegenEnv env(op, options, numTensors, numLoops,
-                   /*numFilterLoops=*/idxReducBased ? 0 : numFilterLoops,
-                   maxLvlRank);
-
+    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    // Detects sparse annotations and translates the per-level sparsity
    // information for all tensors to loop indices in the kernel.
-    if (!findSparseAnnotations(env, idxReducBased))
+    if (!findSparseAnnotations(env, needIdxRed))
      return failure();
 
    // Only standard reduction operations (add, sub, or, xor) that can be

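The net effect of the Sparsification.cpp changes is that findAffine no longer allocates filter loops: a non-trivial affine index now either goes through the dependent-index-reduction path (findDepIdxSet) or must sit on a dense level. Below is a small standalone C++ model of the simplified recursion; the Expr type and findAffineModel name are invented for this sketch (the real code works on mlir::AffineExpr and also records level types in the DimId case).

#include <cassert>

// Toy stand-in for an affine expression tree: a loop index d<i>, a constant,
// or an add/mul of two sub-expressions.
struct Expr {
  enum Kind { DimId, Constant, Add, Mul } kind;
  const Expr *lhs = nullptr, *rhs = nullptr;
};

// Model of the post-cleanup findAffine: a plain loop index is always
// admissible; compound expressions are asserted to sit on dense levels and
// the recursion merely checks that their operands are admissible too.
bool findAffineModel(const Expr &e, bool levelIsDense) {
  switch (e.kind) {
  case Expr::DimId:
    return true; // trivial index d<i>
  case Expr::Add:
  case Expr::Mul:
    assert(levelIsDense && "compound affine only supported on dense levels");
    return findAffineModel(*e.lhs, levelIsDense) &&
           findAffineModel(*e.rhs, levelIsDense);
  case Expr::Constant:
    assert(levelIsDense && "compound affine only supported on dense levels");
    return true; // constants fall through as admissible
  }
  return false;
}

int main() {
  // (d0 + d1) on a dense level is admissible; the same expression on a
  // sparse level must now use dependent index reduction instead.
  const Expr d0{Expr::DimId}, d1{Expr::DimId};
  const Expr sum{Expr::Add, &d0, &d1};
  assert(findAffineModel(sum, /*levelIsDense=*/true));
  return 0;
}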
mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp

Lines changed: 4 additions & 4 deletions

@@ -220,12 +220,12 @@ TensorExp::TensorExp(TensorExp::Kind k, unsigned x, ExprId y, Value v,
   llvm_unreachable("unexpected kind");
 }
 
-Merger::Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
-               unsigned numFilterLoops, unsigned maxLvlRank)
+Merger::Merger(unsigned numInputOutputTensors, unsigned numLoops,
+               unsigned maxLvlRank)
     : outTensor(numInputOutputTensors - 1),
       syntheticTensor(numInputOutputTensors),
-      numTensors(numInputOutputTensors + 1), numNativeLoops(numNativeLoops),
-      numLoops(numNativeLoops + numFilterLoops), hasSparseOut(false),
+      numTensors(numInputOutputTensors + 1), numLoops(numLoops),
+      hasSparseOut(false),
       lvlTypes(numTensors,
                std::vector<DimLevelType>(numLoops, DimLevelType::Undef)),
       loopToLvl(numTensors,

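With the two loop counts collapsed into one, constructing a merger directly now looks as follows. This is a minimal sketch assuming a kernel with two input/output tensors, three loops, and a maximum level rank of three; the concrete values and the standalone use outside the sparsifier are for illustration only.

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using mlir::sparse_tensor::Merger;

void buildMergerSketch() {
  // Old form: Merger(numTensors, numNativeLoops, numFilterLoops, maxLvlRank).
  // New form drops the filter-loop count entirely.
  Merger merger(/*numInputOutputTensors=*/2, /*numLoops=*/3, /*maxLvlRank=*/3);
  (void)merger; // per the initializer list above: output tensor id 1, synthetic tensor id 2
}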
mlir/unittests/Dialect/SparseTensor/MergerTest.cpp

Lines changed: 1 addition & 2 deletions

@@ -123,8 +123,7 @@ FOREVERY_BINOP(IMPL_BINOP_PATTERN)
 class MergerTestBase : public ::testing::Test {
 protected:
   MergerTestBase(unsigned numTensors, unsigned numLoops)
-      : merger(numTensors, numLoops, /*numFilterLoops=*/0,
-               /*maxRank=*/numLoops) {
+      : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
     tensors.reserve(numTensors);
     for (unsigned t = 0; t < numTensors; t++)
       tensors.push_back(merger.addTensorExp(tid(t)));

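Downstream test fixtures only have to drop the /*numFilterLoops=*/0 argument. A hypothetical derived fixture built on the updated MergerTestBase (class name and counts invented for illustration):

// Hypothetical fixture: three tensors (two inputs, one output) and one loop.
class MergerTestSketch : public MergerTestBase {
protected:
  MergerTestSketch() : MergerTestBase(/*numTensors=*/3, /*numLoops=*/1) {}
};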