@@ -56,6 +56,13 @@ using LatPointId = unsigned;
56
56
// / for the corresponding `SmallVector<LatPointId>` object.
57
57
using LatSetId = unsigned ;
58
58
59
+ // / A pair of level and its corresponding DimLevelType of a tensor.
60
+ using LvlDLTPair = std::pair<Level, DimLevelType>;
61
+
62
+ // / A pair of loop id and its coefficients. E.g., for affine expression in the
63
+ // / affine map `2 * d0`, loop id = 0, coefficient = 2.
64
+ using LoopCoeffPair = std::pair<LoopId, unsigned >;
65
+
59
66
// / Tensor expression. Represents an MLIR expression in tensor index notation.
60
67
struct TensorExp final {
61
68
enum class Kind ;
@@ -509,22 +516,22 @@ class Merger {
509
516
510
517
// / Establishes the two-way map that i <-> <t, lvl, dlt>.
511
518
void setLoopDependentTensorLevel (LoopId i, TensorId t, Level lvl,
512
- DimLevelType dlt) {
519
+ DimLevelType dlt, unsigned coefficient ) {
513
520
assert (isValidLoopId (i) && isValidLevel (t, lvl));
514
- assert (!loopToDependencies [i][t].has_value ()); // must be the first def
515
- loopToDependencies [i][t] = std::make_pair (lvl, dlt);
516
- levelToDependentLoop[t][lvl].push_back (i );
521
+ assert (!loopToUnresolvedLvls [i][t].has_value ()); // must be the first def
522
+ loopToUnresolvedLvls [i][t] = std::make_pair (lvl, dlt);
523
+ levelToDependentLoop[t][lvl].emplace_back (i, coefficient );
517
524
}
518
525
519
526
// / Whether the loop has dependent slice.
520
527
bool hasDependentLvl (LoopId i, TensorId t) {
521
528
assert (isValidTensorId (t) && isValidLoopId (i));
522
- return loopToDependencies [i][t].has_value ();
529
+ return loopToUnresolvedLvls [i][t].has_value ();
523
530
}
524
531
525
532
// / Returns the list of loop indices which appear in the non-trivial index
526
533
// / expression on t_l, e.g., A[i+j] => {i, j}
527
- std::vector<LoopId > &getDependentLoops (TensorId t, Level lvl) {
534
+ std::vector<LoopCoeffPair > &getDependentLoops (TensorId t, Level lvl) {
528
535
assert (isValidLevel (t, lvl));
529
536
return levelToDependentLoop[t][lvl];
530
537
}
@@ -541,7 +548,7 @@ class Merger {
541
548
const TensorId t = tensor (b);
542
549
const LoopId i = loop (b);
543
550
assert (isValidTensorId (t) && isValidLoopId (i));
544
- return loopToDependencies [i][t].has_value ();
551
+ return loopToUnresolvedLvls [i][t].has_value ();
545
552
}
546
553
547
554
// / Checks whether the TensorLoopId represents a sparse tensor level contains
@@ -556,12 +563,12 @@ class Merger {
556
563
557
564
Level getLoopDependentLevel (TensorLoopId b) const {
558
565
assert (isLvlWithNonTrivialIdxExp (b));
559
- return loopToDependencies [loop (b)][tensor (b)]->first ;
566
+ return loopToUnresolvedLvls [loop (b)][tensor (b)]->first ;
560
567
}
561
568
562
569
DimLevelType getLoopDependentLevelType (TensorLoopId b) const {
563
570
assert (isLvlWithNonTrivialIdxExp (b));
564
- return loopToDependencies [loop (b)][tensor (b)]->second ;
571
+ return loopToUnresolvedLvls [loop (b)][tensor (b)]->second ;
565
572
}
566
573
567
574
// / Convenience getters to immediately access the stored nodes.
@@ -715,13 +722,13 @@ class Merger {
715
722
// / It is currently only set for non-trivial index expressions.
716
723
// / E.g., A[i+j] => i and j will have dependencies {A0, dlt(A0)} to indicate
717
724
// / that i and j are used in the non-trivial index expression on A0.
718
- std::vector<std::vector<std::optional<std::pair<Level, DimLevelType>>>>
719
- loopToDependencies;
725
+ std::vector<std::vector<std::optional<LvlDLTPair>>> loopToUnresolvedLvls;
720
726
721
727
// / The inverse map of ldxToDependencies from tensor level -> dependent loop
722
- // / E.g., A[i+j], we have A0 => {i, j}, to indicate that A0 uses both {i, j}
723
- // / to compute its indices.
724
- std::vector<std::vector<std::vector<LoopId>>> levelToDependentLoop;
728
+ // / E.g., A[2i+j], we have A0 => {(2, i), (1, j)}, to indicate that A0 uses
729
+ // / both {i, j} to compute its indices and the coefficients on the loop id are
730
+ // / 2 and 1 respectively.
731
+ std::vector<std::vector<std::vector<LoopCoeffPair>>> levelToDependentLoop;
725
732
726
733
// / Map from a loop to the [tid, lvl] pair that defines the loop boundary.
727
734
std::vector<std::pair<TensorId, Level>> loopBounds;
0 commit comments