Skip to content

Commit e015d38

Browse files
author
Peiming Liu
committed
[mlir][sparse] Pass down constant coefficients of affine index expressions to LoopEmitter.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D158914
1 parent 680da4b commit e015d38

File tree

7 files changed

+181
-108
lines changed

7 files changed

+181
-108
lines changed

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ using LatPointId = unsigned;
5656
/// for the corresponding `SmallVector<LatPointId>` object.
5757
using LatSetId = unsigned;
5858

59+
/// A pair of level and its corresponding DimLevelType of a tensor.
60+
using LvlDLTPair = std::pair<Level, DimLevelType>;
61+
62+
/// A pair of loop id and its coefficients. E.g., for affine expression in the
63+
/// affine map `2 * d0`, loop id = 0, coefficient = 2.
64+
using LoopCoeffPair = std::pair<LoopId, unsigned>;
65+
5966
/// Tensor expression. Represents an MLIR expression in tensor index notation.
6067
struct TensorExp final {
6168
enum class Kind;
@@ -509,22 +516,22 @@ class Merger {
509516

510517
/// Establishes the two-way map that i <-> <t, lvl, dlt>.
511518
void setLoopDependentTensorLevel(LoopId i, TensorId t, Level lvl,
512-
DimLevelType dlt) {
519+
DimLevelType dlt, unsigned coefficient) {
513520
assert(isValidLoopId(i) && isValidLevel(t, lvl));
514-
assert(!loopToDependencies[i][t].has_value()); // must be the first def
515-
loopToDependencies[i][t] = std::make_pair(lvl, dlt);
516-
levelToDependentLoop[t][lvl].push_back(i);
521+
assert(!loopToUnresolvedLvls[i][t].has_value()); // must be the first def
522+
loopToUnresolvedLvls[i][t] = std::make_pair(lvl, dlt);
523+
levelToDependentLoop[t][lvl].emplace_back(i, coefficient);
517524
}
518525

519526
/// Whether the loop has dependent slice.
520527
bool hasDependentLvl(LoopId i, TensorId t) {
521528
assert(isValidTensorId(t) && isValidLoopId(i));
522-
return loopToDependencies[i][t].has_value();
529+
return loopToUnresolvedLvls[i][t].has_value();
523530
}
524531

525532
/// Returns the list of loop indices which appear in the non-trivial index
526533
/// expression on t_l, e.g., A[i+j] => {i, j}
527-
std::vector<LoopId> &getDependentLoops(TensorId t, Level lvl) {
534+
std::vector<LoopCoeffPair> &getDependentLoops(TensorId t, Level lvl) {
528535
assert(isValidLevel(t, lvl));
529536
return levelToDependentLoop[t][lvl];
530537
}
@@ -541,7 +548,7 @@ class Merger {
541548
const TensorId t = tensor(b);
542549
const LoopId i = loop(b);
543550
assert(isValidTensorId(t) && isValidLoopId(i));
544-
return loopToDependencies[i][t].has_value();
551+
return loopToUnresolvedLvls[i][t].has_value();
545552
}
546553

547554
/// Checks whether the TensorLoopId represents a sparse tensor level contains
@@ -556,12 +563,12 @@ class Merger {
556563

557564
Level getLoopDependentLevel(TensorLoopId b) const {
558565
assert(isLvlWithNonTrivialIdxExp(b));
559-
return loopToDependencies[loop(b)][tensor(b)]->first;
566+
return loopToUnresolvedLvls[loop(b)][tensor(b)]->first;
560567
}
561568

562569
DimLevelType getLoopDependentLevelType(TensorLoopId b) const {
563570
assert(isLvlWithNonTrivialIdxExp(b));
564-
return loopToDependencies[loop(b)][tensor(b)]->second;
571+
return loopToUnresolvedLvls[loop(b)][tensor(b)]->second;
565572
}
566573

567574
/// Convenience getters to immediately access the stored nodes.
@@ -715,13 +722,13 @@ class Merger {
715722
/// It is currently only set for non-trivial index expressions.
716723
/// E.g., A[i+j] => i and j will have dependencies {A0, dlt(A0)} to indicate
717724
/// that i and j are used in the non-trivial index expression on A0.
718-
std::vector<std::vector<std::optional<std::pair<Level, DimLevelType>>>>
719-
loopToDependencies;
725+
std::vector<std::vector<std::optional<LvlDLTPair>>> loopToUnresolvedLvls;
720726

721727
/// The inverse map of ldxToDependencies from tensor level -> dependent loop
722-
/// E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
723-
/// to compute its indices.
724-
std::vector<std::vector<std::vector<LoopId>>> levelToDependentLoop;
728+
/// E.g., A[2i+j], we have A0 => {(2, i), (1, j)}, to indicate that A0 uses
729+
/// both {i, j} to compute its indices and the coefficients on the loop id are
730+
/// 2 and 1 respectively.
731+
std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
725732

726733
/// Map from a loop to the [tid, lvl] pair that defines the loop boundary.
727734
std::vector<std::pair<TensorId, Level>> loopBounds;

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ static bool isMaterializing(Value val) {
2929
}
3030

3131
/// Makes target array's elements sorted according to the `order` array.
32-
static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
32+
static void sortArrayBasedOnOrder(std::vector<LoopCoeffPair> &target,
3333
ArrayRef<LoopId> order) {
3434
std::sort(target.begin(), target.end(),
35-
[&order](const LoopId &l, const LoopId &r) {
35+
[&order](const LoopCoeffPair &l, const LoopCoeffPair &r) {
3636
assert(std::addressof(l) == std::addressof(r) || l != r);
3737
int idxL = -1, idxR = -1;
3838
for (int i = 0, e = order.size(); i < e; i++) {
39-
if (order[i] == l)
39+
if (order[i] == l.first)
4040
idxL = i;
41-
if (order[i] == r)
41+
if (order[i] == r.first)
4242
idxR = i;
4343
}
4444
assert(idxL >= 0 && idxR >= 0);
@@ -104,13 +104,17 @@ void CodegenEnv::startEmit() {
104104
/*isSparseOut=*/sparseOut != nullptr, topSort,
105105
// TODO: compute the map and pass it to loop emitter directly instead of
106106
// passing in a callback.
107-
[this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
108-
// Translates from a list of loop index to a list of [tid, dim] pair.
109-
std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
110-
std::vector<std::pair<TensorId, Level>> ret;
107+
/*dependentLvlGetter=*/
108+
[this](TensorId t,
109+
Level lvl) -> std::vector<std::pair<TensorLevel, unsigned>> {
110+
// Translates from a list of loop indices to a list of [tid, lvl] pair.
111+
std::vector<LoopCoeffPair> &rLoops = merger().getDependentLoops(t, lvl);
112+
std::vector<std::pair<TensorLevel, unsigned>> ret;
111113
ret.reserve(rLoops.size());
112-
for (LoopId l : rLoops)
113-
ret.emplace_back(this->merger().getLoopDefiningLvl(l));
114+
for (auto [loop, coeff] : rLoops) {
115+
TensorLevel tl = makeTensorLevel(merger().getLoopDefiningLvl(loop));
116+
ret.emplace_back(tl, coeff);
117+
};
114118
return ret;
115119
});
116120
}

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ class CodegenEnv {
9696
loopEmitter.getSynTensorId() == latticeMerger.getSynTensorID());
9797
return loopEmitter.makeTensorLevel(t, l);
9898
}
99+
TensorLevel makeTensorLevel(std::pair<TensorId, Level> tlPair) const {
100+
return makeTensorLevel(tlPair.first, tlPair.second);
101+
}
99102
std::pair<TensorId, Level> unpackTensorLevel(TensorLevel tl) const {
100103
return loopEmitter.unpackTensorLevel(tl);
101104
}

0 commit comments

Comments
 (0)