@@ -44,23 +44,23 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//

-/// Determines if affine expression is invariant.
-static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
-                              bool &isAtLoop) {
+/// Returns true iff the affine expression is invariant. Sets the
+/// parameter `isAtLoop` when the expression just became invariant.
+static bool isInvariantAffine(AffineExpr a, LoopId at, bool &isAtLoop) {
   switch (a.getKind()) {
   case AffineExprKind::DimId: {
     const LoopId i = cast<AffineDimExpr>(a).getPosition();
-    if (i == ldx) {
+    if (i + 1 == at) {
       isAtLoop = true;
-      return true; // invariant at given loop
+      return true; // becomes invariant at current loop
     }
-    return i < loopDepth; // invariant when already generated
+    return i < at; // invariant when already generated
   }
   case AffineExprKind::Add:
   case AffineExprKind::Mul: {
     auto binOp = cast<AffineBinaryOpExpr>(a);
-    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
-           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
+    return isInvariantAffine(binOp.getLHS(), at, isAtLoop) &&
+           isInvariantAffine(binOp.getRHS(), at, isAtLoop);
   }
   default: {
     assert(isa<AffineConstantExpr>(a));
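
The rewrite folds the old `(loopDepth, ldx)` pair into the single `at`: loops `0 .. at-1` count as already generated, and `d(at-1)` is the loop the caller is currently at. A quick sanity sketch of that contract (illustrative only, not part of the patch; assumes the file-static helper were visible to a test and an MLIRContext is at hand):

    MLIRContext ctx;
    AffineExpr d0 = getAffineDimExpr(0, &ctx);
    AffineExpr d1 = getAffineDimExpr(1, &ctx);
    bool isAtLoop = false;
    // at = 2: loops for d0 and d1 are generated, d1 is the current one,
    // so d0 + d1 is invariant and just became so (isAtLoop is set).
    assert(isInvariantAffine(d0 + d1, /*at=*/2, isAtLoop) && isAtLoop);
    // at = 1: d1 has not been generated yet, so the sum is still in play.
    isAtLoop = false;
    assert(!isInvariantAffine(d0 + d1, /*at=*/1, isAtLoop));
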
@@ -126,23 +126,23 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
   if (coefficient <= 0)
     return false;

-  const LoopId ldx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
-  if (!isUndefLT(merger.getLvlType(tensor, ldx)))
+  const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
+  if (!isUndefLT(merger.getLvlType(tensor, idx)))
     return false; // used more than once, e.g., A[i][i]

   // TODO: Generalize the following two cases. A[i] (with trivial index
   // expression) can be treated as a special affine index expression. We do
   // not necessarily need to differentiate them.
   if (!isSubExp) {
     assert(coefficient == 1);
-    merger.setLevelAndType(tensor, ldx, lvl, lt);
+    merger.setLevelAndType(tensor, idx, lvl, lt);
   }

   if (isSubExp) {
     // The current loop appears in more than one affine expression on the
     // same tensor. We cannot handle this case, e.g., A[i+j][i+k]: `i` is
     // used twice.
-    if (merger.hasDependentLvl(ldx, tensor)) {
+    if (merger.hasDependentLvl(idx, tensor)) {
       // TODO: This can be supported by coiterating slices if the loop idx
       // appears in an affine index of a different tensor, or by taking a
       // slice on multiple dimensions when it is on the same tensor.
@@ -154,7 +154,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
       // else increase min(d0_1, d0_2).
       return false;
     }
-    merger.setLoopDependentTensorLevel(ldx, tensor, lvl, lt, coefficient);
+    merger.setLoopDependentTensorLevel(idx, tensor, lvl, lt, coefficient);
   }
   return true;
 }
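
For a non-trivial index, the recursion above records one (loop, coefficient) pair per dimension on the tensor level. An illustrative build of the kind of affine index it walks (hypothetical driver code assuming an MLIRContext; not from the patch):

    MLIRContext ctx;
    AffineExpr d0 = getAffineDimExpr(0, &ctx);
    AffineExpr d1 = getAffineDimExpr(1, &ctx);
    // A[2 * d0 + d1]: walking Add(Mul(d0, 2), d1) records (d0, coeff 2)
    // and (d1, coeff 1); reusing d0 on the same tensor, as in A[d0][d0 + d1],
    // makes the walk bail out per the checks above.
    AffineExpr access = d0 * 2 + d1;
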
@@ -613,9 +613,9 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
   if (kind == TensorExp::Kind::kReduce)
     env.startCustomReduc(e); // enter custom

-  Value v0, v1;
   // If either lhs/rhs is a synthetic zero, we infer the type for the zero value
   // based on the type of the other operand.
+  Value v0, v1;
   if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
       env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
     v1 = genExp(env, rewriter, exp.children.e1);
@@ -655,21 +655,21 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {

 /// Hoists loop invariant tensor loads for which indices have been exhausted.
 static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                          LoopId ldx, bool atStart) {
+                          LoopId at, bool atStart) {
   if (exp == ::mlir::sparse_tensor::detail::kInvalidId)
     return;
   if (env.exp(exp).kind == TensorExp::Kind::kTensor) {
     // Inspect tensor indices.
-    bool isAtLoop = ldx == ::mlir::sparse_tensor::detail::kInvalidId;
     linalg::GenericOp op = env.op();
     OpOperand &t = op->getOpOperand(env.exp(exp).tensor);
     const auto map = op.getMatchingIndexingMap(&t);
     const auto stt = getSparseTensorType(t.get());
     const Level lvlRank = stt.getLvlRank();
     assert(static_cast<Level>(map.getNumResults()) == lvlRank);
+    bool isAtLoop = at == 0; // for scalar tensors
     for (Level l = 0; l < lvlRank; l++) {
       const AffineExpr a = map.getResult(l);
-      if (!isInvariantAffine(a, env.getLoopDepth(), ldx, isAtLoop))
+      if (!isInvariantAffine(a, at, /*out*/ isAtLoop))
         return; // still in play
     }
     // All exhausted at this level (isAtLoop denotes exactly at this LoopId).
@@ -705,8 +705,8 @@ static void genInvariants(CodegenEnv &env, OpBuilder &builder, ExprId exp,
     env.startCustomReduc(exp); // enter custom
   const ExprId e0 = env.exp(exp).children.e0;
   const ExprId e1 = env.exp(exp).children.e1;
-  genInvariants(env, builder, e0, ldx, atStart);
-  genInvariants(env, builder, e1, ldx, atStart);
+  genInvariants(env, builder, e0, at, atStart);
+  genInvariants(env, builder, e1, at, atStart);
   if (env.exp(exp).kind == TensorExp::Kind::kReduce)
     env.endCustomReduc(); // exit custom
 }
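
In scalar terms, the hoisting that genInvariants performs moves a load whose indices are exhausted out of the loop about to be built. A plain C++ analogue (illustrative only; the sparsifier emits the equivalent IR, and the names here are made up):

    // b[j] is invariant in the i-loop, so its load runs once, up front.
    double dotRow(const double *a, const double *b, size_t n, size_t j) {
      double t = b[j]; // hoisted: index j exhausted at this depth
      double x = 0.0;
      for (size_t i = 0; i < n; i++)
        x += a[i] * t;
      return x;
    }
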
@@ -782,29 +782,28 @@ static bool isParallelFor(CodegenEnv &env, bool isOuter, bool isSparse) {

 /// Whether or not the current loop being generated should be parallelized
 /// (if possible) according to the configuration.
-static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
+static bool shouldTryParallize(CodegenEnv &env, LoopId at,
                                ArrayRef<TensorLevel> tidLvls) {
   linalg::GenericOp op = env.op();
   auto iteratorTypes = op.getIteratorTypesArray();
-  bool isSparse = llvm::any_of(tidLvls, [ldx, &env](TensorLevel tidLvl) {
-    // Queries the LT based on the tensor id and loop idx, as requested by
-    // `CodegenEnv::lt(TensorId, LoopIdx)`. The returned LT from CodegenEnv
+  bool isSparse = llvm::any_of(tidLvls, [at, &env](TensorLevel tidLvl) {
+    // Queries the LT based on the tensor and loop id, as requested by
+    // `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
     // should be consistent with the LT indexed by <TensorId, Level>.
-    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
+    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, at);
     return isCompressedLT(lt) || isSingletonLT(lt);
   });
-  return isParallelFor(env, isOuter, isSparse);
+  return isParallelFor(env, /*isOuter=*/at == 0, isSparse);
 }

 /// Emit a loop to coiterate over the list of tensor levels. The generated loop
 /// can either be a for loop or while loop depending on whether there is at most
 /// one sparse level in the list.
 static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
-                                 LoopId idx, ArrayRef<TensorLevel> tidLvls,
+                                 ArrayRef<TensorLevel> tidLvls,
                                  bool tryParallel, bool needsUniv) {
   Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
-    // Construct the while-loop with a parameter for each
-    // index.
+    // Construct while-loop with a parameter for each index.
     return env.emitter().enterCoIterationOverTensorsAtLvls(
         builder, env.op().getLoc(), tidLvls, reduc, tryParallel,
         /*genDedup=*/true, needsUniv);
@@ -817,12 +816,12 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
 /// singleton iteration or co-iteration over the given conjunction.
 static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId at,
                           bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
-  bool tryParallel = shouldTryParallize(env, at, at == 0, tidLvls);
-  return genCoIteration(env, builder, at, tidLvls, tryParallel, needsUniv);
+  bool tryParallel = shouldTryParallize(env, at, tidLvls);
+  return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
 }

 /// Generates the induction structure for a while-loop.
-static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
+static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
                             bool needsUniv) {
   Location loc = env.op().getLoc();
   // Finalize each else branch of all if statements.
@@ -862,7 +861,7 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder, LoopId idx,
 }

 /// Generates a single if-statement within a while-loop.
-static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
+static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId at,
                        LatPointId p) {
   Location loc = env.op().getLoc();
   SmallVector<Type> types;
@@ -880,13 +879,13 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId ldx,
       auto stt = getSparseTensorType(env.op().getInputs()[tid]);
       lt = stt.getLvlType(*lvl);
     }
-    assert(ldx == env.merger().loop(b));
+    assert(at == env.merger().loop(b));
     Value clause;
     if (isCompressedLT(lt) || isSingletonLT(lt) ||
         isLooseCompressedLT(lt) || is2OutOf4LT(lt)) {
       assert(lvl.has_value());
       const Value crd = env.emitter().getCoords()[tid][*lvl];
-      const Value lvar = env.getLoopVar(ldx);
+      const Value lvar = env.getLoopVar(at);
       clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                              crd, lvar);
     } else {
@@ -943,12 +942,12 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 /// Starts a loop sequence at given level. Returns true if
 /// the universal loop index must be maintained at this level.
 static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId idx, LoopId ldx, LatSetId lts) {
-  assert(!env.getLoopVar(idx));
+                         LoopId at, LatSetId lts) {
+  assert(!env.getLoopVar(at));
   // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/true);
+  genInvariants(env, builder, exp, at, /*atStart=*/true);
   // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/true);
+  genExpand(env, builder, at, /*atStart=*/true);
   // Emit further initialization at this loop sequence level.
   const LatPointId l0 = env.set(lts)[0];
   bool needsUniv = false;
@@ -957,7 +956,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
                                            std::optional<Level> lvl,
                                            LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == idx);
+    assert(env.merger().loop(b) == at);
     if (isDenseLT(lt) || isUndefLT(lt)) {
       if (tid == env.merger().getSynTensorID()) {
         // Needs loop emitter to set up loop bounds for synthetic tensor too if
@@ -988,6 +987,7 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
   return false;
 }

+// Generates dense affine address for encoding.
 static void genConstantDenseAddressFromLevel(CodegenEnv &env,
                                              OpBuilder &builder, TensorId tid,
                                              Level startLvl) {
@@ -1013,30 +1013,30 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
   }
 }

+// We can generate addresses for constant affine expressions before any loops,
+// starting from the first level, as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0): the addresses for the first two
+// levels can be determined before loops.
 static void genInitConstantDenseAddress(CodegenEnv &env,
                                         RewriterBase &rewriter) {
-  // We can generate address for constant affine expression before any loops
-  // starting from the first level as they do not depend on any thing.
-  // E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-  // levels can be determined before loops.
   for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
     genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
 }

 /// Return true if the lattice bits can be iterated by a for loop.
 static bool translateBitsToTidLvlPairs(
-    CodegenEnv &env, LatPointId li, LoopId ldx,
+    CodegenEnv &env, LatPointId li, LoopId at,
     SmallVectorImpl<TensorLevel> &tidLvls,
     SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
-  const std::optional<Level> outLvl = env.merger().getLvl(outTid, ldx);
+  const std::optional<Level> outLvl = env.merger().getLvl(outTid, at);

   unsigned numloopCond = 0;
   bool hasNonUnique = false;
-  env.merger().foreachTensorLoopId(li, [&, ldx](TensorLoopId b, TensorId tid,
-                                                std::optional<Level> lvl,
-                                                LevelType lt, bool isIdxReduc) {
+  env.merger().foreachTensorLoopId(li, [&, at](TensorLoopId b, TensorId tid,
+                                               std::optional<Level> lvl,
+                                               LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
         tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
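
To make the relocated comment above concrete: for the [Dense, Dense, Sparse] example with access (1, 2, d0), the two constant dense coordinates fold into one base position that is computable before any loop is emitted. A sketch of the arithmetic (assuming the usual dense-level linearization pos(l) = pos(l-1) * size(l) + coord(l); the function name is hypothetical):

    uint64_t constDenseBase(uint64_t size1) {
      uint64_t pos0 = 1;                // constant coordinate at level 0
      uint64_t pos1 = pos0 * size1 + 2; // constant coordinate at level 1
      return pos1;                      // loop-independent base address
    }
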
@@ -1089,11 +1089,11 @@ static bool translateBitsToTidLvlPairs(
       if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
         continue;

-      // Constant affine expression are handled in genLoop
+      // Constant affine expressions are handled in genLoop.
       if (!isa<AffineConstantExpr>(exp)) {
         bool isAtLoop = false;
-        if (isInvariantAffine(exp, env.getLoopDepth(), ldx, isAtLoop) &&
-            isAtLoop) {
+        assert(at == env.getLoopDepth());
+        if (isInvariantAffine(exp, at + 1, /*out*/ isAtLoop) && isAtLoop) {
          // If the compound affine is invariant and we are right at the
          // level, we need to generate the address according to the
          // affine expression. This is also the best place we can do it
@@ -1105,7 +1105,7 @@ static bool translateBitsToTidLvlPairs(
     }
   });

-  if (isDenseLT(env.lt(outTid, ldx))) {
+  if (isDenseLT(env.lt(outTid, at))) {
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
@@ -1131,9 +1131,9 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               LatPointId li, bool needsUniv) {
   // The set of tensors + lvls to generate loops on.
   SmallVector<TensorLevel> tidLvls;
+
   // The set of dense tensors with non-trivial affine expression that just
-  // becomes invariant and the address shall now be generated at the current
-  // level.
+  // becomes invariant and the addresses are generated at the current level.
   SmallVector<std::pair<TensorLevel, AffineExpr>> affineTidLvls;
   bool isSingleCond =
       translateBitsToTidLvlPairs(env, li, at, tidLvls, affineTidLvls);
@@ -1161,59 +1161,56 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,

 /// Ends a single loop in current sequence. Returns new values for needsUniv.
 static bool endLoop(CodegenEnv &env, RewriterBase &rewriter, Operation *loop,
-                    LoopId idx, LatPointId li, bool needsUniv,
-                    bool isSingleCond) {
-
+                    LatPointId li, bool needsUniv, bool isSingleCond) {
+  // Either a for-loop or a while-loop that iterates over a slice.
   if (isSingleCond) {
-    // Either a for-loop or a while-loop that iterates over a slice.
     // Any iteration creates a valid lex insert.
     if (env.isReduc() && env.getValidLexInsert())
       env.setValidLexInsert(constantI1(rewriter, env.op().getLoc(), true));
   } else if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) {
     // End a while-loop.
-    finalizeWhileOp(env, rewriter, idx, needsUniv);
+    finalizeWhileOp(env, rewriter, needsUniv);
   } else {
     needsUniv = false;
   }
-
   env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
     env.emitter().exitCurrentLoop(rewriter, env.op().getLoc(), reduc);
     return std::nullopt;
   });
-
   return needsUniv;
 }

 /// Ends a loop sequence at given level.
 static void endLoopSeq(CodegenEnv &env, OpBuilder &builder, unsigned exp,
-                       unsigned idx, unsigned ldx) {
-  assert(!env.getLoopVar(idx));
+                       unsigned at) {
+  assert(!env.getLoopVar(at));
   env.emitter().exitCurrentLoopSeq(builder, env.op().getLoc());
   // Unmark bookkeeping of invariants and loop index.
-  genInvariants(env, builder, exp, ldx, /*atStart=*/false);
+  genInvariants(env, builder, exp, at, /*atStart=*/false);
   // Finalize access pattern expansion for sparse tensor output.
-  genExpand(env, builder, idx, /*atStart=*/false);
+  genExpand(env, builder, at, /*atStart=*/false);
 }

 /// Recursively generates code while computing iteration lattices in order
 /// to manage the complexity of implementing co-iteration over unions
 /// and intersections of sparse iteration spaces.
 static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                     LoopId at) {
+  assert(at == env.getLoopDepth());
+
   // At each leaf, assign remaining tensor (sub)expression to output tensor.
   if (at == env.getLoopNum()) {
     Value rhs = genExp(env, rewriter, exp);
     genTensorStore(env, rewriter, exp, rhs);
     return;
   }

-  // Construct iteration lattices for current loop index, with L0 at top.
-  const LoopId ldx = at == 0 ? sparse_tensor::detail::kInvalidId : at - 1;
+  // Construct iteration lattices for current loop index.
   const LatSetId lts =
       env.merger().optimizeSet(env.merger().buildLattices(exp, at));

   // Start a loop sequence.
-  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
+  bool needsUniv = startLoopSeq(env, rewriter, exp, at, lts);

   // Emit a loop for every lattice point L0 >= Li in this loop sequence.
   // We cannot change this to `for (const LatPointId li : env.set(lts))`
@@ -1250,11 +1247,12 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
     }

     // End a loop.
-    needsUniv = endLoop(env, rewriter, loop, at, li, needsUniv, isSingleCond);
+    needsUniv = endLoop(env, rewriter, loop, at, needsUniv, isSingleCond);
   }

   // End a loop sequence.
-  endLoopSeq(env, rewriter, exp, at, ldx);
+  endLoopSeq(env, rewriter, exp, at);
+  assert(at == env.getLoopDepth());
 }

 /// Converts the result computed by the sparse kernel into the required form.
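
The two new assertions bracket the recursion: genStmt now enters and leaves every level with the emitter's loop depth equal to `at`, which is what made dropping the separate `ldx` bookkeeping safe. A minimal analogue of that invariant (illustrative only; not the generator itself):

    #include <cassert>
    void walk(unsigned at, unsigned numLoops, unsigned &depth) {
      assert(depth == at);  // loops 0 .. at-1 are open on entry
      if (at == numLoops)
        return;             // leaf: the store is emitted here
      ++depth;              // the loop for `at` is opened
      walk(at + 1, numLoops, depth);
      --depth;              // ... and closed again
      assert(depth == at);  // balanced on exit
    }
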
@@ -1309,6 +1307,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
           op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
               "before sparsification.");
     }
+
     // Must have been demapped as well if the generic op is sorted.
     assert(!hasAnyNonIdentityOperandsOrResults(op));