@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
 // Sparsifier synthesis methods (loop sequence).
 //===----------------------------------------------------------------------===//
 
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
-                         LoopId curr, LatSetId lts) {
-  assert(!env.getLoopVar(curr));
-  // Emit invariants at this loop sequence level.
-  genInvariants(env, builder, exp, curr, /*isStart=*/true);
-  // Emit access pattern expansion for sparse tensor output.
-  genExpand(env, builder, curr, /*isStart=*/true);
-  // Emit further intitialization at this loop sequence level.
-  const LatPointId l0 = env.set(lts)[0];
-  bool needsUniv = false;
-
-  SmallVector<TensorLevel> tidLvls;
-  env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
-                                           std::optional<Level> lvl,
-                                           LevelType lt, bool isIdxReduc) {
-    assert(env.merger().loop(b) == curr);
-    if (isDenseLT(lt) || isUndefLT(lt)) {
-      if (tid == env.merger().getSynTensorID()) {
-        // Needs loop emitter to set up loop bounds for synthetic tensor too if
-        // there is a loop condition imposed on the synthetic tensor.
-        tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
-      }
-      needsUniv = true;
-    }
-    if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
-        is2OutOf4LT(lt) || isIdxReduc) {
-      // Only when this is a index reduction loop, can the lt be undefined.
-      assert(!isUndefLT(lt) || isIdxReduc);
-      // sparse/singleton levels, or a dense/sparse index reduction loop.
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
-    }
-  });
-
-  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
-  // Maintain the universal index only if it is actually
-  // consumed by a subsequent lattice point.
-  if (needsUniv) {
-    for (const LatPointId li : env.set(lts).drop_front())
-      if (!env.merger().hasAnySparse(env.lat(li).simple))
-        return true;
-  }
-  return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
-                                             OpBuilder &builder, TensorId tid,
-                                             Level startLvl) {
-  // TODO: Handle affine expression on output tensor.
-  linalg::GenericOp op = env.op();
-  assert(tid < op.getNumDpsInputs());
-  OpOperand *input = op.getDpsInputOperands()[tid];
-  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
-  const auto enc = getSparseTensorEncoding(input->get().getType());
-  if (enc) {
-    const Location loc = op.getLoc();
-    const TensorId tid = env.makeTensorId(input->getOperandNumber());
-    const Level lvlRank = enc.getLvlRank();
-    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
-    for (Level l = startLvl; l < lvlRank; l++) {
-      AffineExpr lvlExpr = lvlExprs[l];
-      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
-        env.emitter().genDenseAffineAddress(
-            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
-      else
-        return; // break on first non-dense non-constant level
-    }
-  }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
-                                        RewriterBase &rewriter) {
-  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
-    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
     CodegenEnv &env, LatPointId li, LoopId curr,
-    SmallVectorImpl<TensorLevel> &tidLvls,
-    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+    llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
   const BitVector &simple = env.lat(li).simple;
   const TensorId outTid = env.merger().getOutTensorID();
   const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
                                                  LevelType lt, bool isIdxReduc) {
     if (simple[b]) {
       if (isIdxReduc) {
-        tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+        callback(env.makeTensorLevel(tid, *lvl), nullptr);
         numloopCond++;
         return;
       }
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
         }
       }
       hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
       numloopCond++;
     } else if (isDenseLT(lt) || isIdxReduc) {
-      tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+      callback(env.makeTensorLevel(tid, *lvl), nullptr);
     } else {
       assert(isUndefLT(lt));
       linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
           // level. We need to generate the address according to the
           // affine expression. This is also the best place we can do it
           // to avoid putting it inside inner loops.
-          affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+          callback(env.makeTensorLevel(tid, l), exp);
         }
       }
     }
@@ -1120,22 +1035,99 @@ static bool translateBitsToTidLvlPairs(
     // Note that we generate dense indices of the output tensor
     // unconditionally, since they may not appear in the lattice, but may be
     // needed for linearized env.
-    tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+    callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
   }
 
   if (numloopCond == 0) {
     // Corner cases where the loop bound is defined by a *unused* operand, in
     // this case, we just generate a dense "fake" loop by iterating over the
     // synthetic tensor.
-    tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
-                                          env.getCurrentDepth()));
+    callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
     numloopCond++;
   }
   // If we just need to one loop conditions and the conditions is not imposed on
   // non-unique level, the loop can be generated by a for loop.
   return numloopCond == 1 && !hasNonUnique;
 }
 
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+                         LoopId curr, LatSetId lts) {
+  assert(!env.getLoopVar(curr));
+  // Emit invariants at this loop sequence level.
+  genInvariants(env, builder, exp, curr, /*isStart=*/true);
+  // Emit access pattern expansion for sparse tensor output.
+  genExpand(env, builder, curr, /*isStart=*/true);
+  // Emit further initialization at this loop sequence level.
+  const LatPointId l0 = env.set(lts)[0];
+
+  SmallVector<TensorLevel> tidLvls;
+  getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+    tidLvls.emplace_back(tl);
+  });
+
+  env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+  // Maintain the universal index only if it is actually
+  // consumed by a subsequent lattice point.
+  for (const LatPointId li : env.set(lts).drop_front())
+    if (!env.merger().hasAnySparse(env.lat(li).simple))
+      return true;
+
+  return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+                                             OpBuilder &builder, TensorId tid,
+                                             Level startLvl) {
+  // TODO: Handle affine expression on output tensor.
+  linalg::GenericOp op = env.op();
+  assert(tid < op.getNumDpsInputs());
+  OpOperand *input = op.getDpsInputOperands()[tid];
+  const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+  const auto enc = getSparseTensorEncoding(input->get().getType());
+  if (enc) {
+    const Location loc = op.getLoc();
+    const TensorId tid = env.makeTensorId(input->getOperandNumber());
+    const Level lvlRank = enc.getLvlRank();
+    assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+    for (Level l = startLvl; l < lvlRank; l++) {
+      AffineExpr lvlExpr = lvlExprs[l];
+      if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+        env.emitter().genDenseAffineAddress(
+            builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+      else
+        return; // break on first non-dense non-constant level
+    }
+  }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+                                        RewriterBase &rewriter) {
+  for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+    genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+    CodegenEnv &env, LatPointId li, LoopId curr,
+    SmallVectorImpl<TensorLevel> &tidLvls,
+    SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+  return getAllTidLvlsInLatPoints(env, li, curr,
+                                  [&](TensorLevel tl, AffineExpr exp) {
+                                    if (exp)
+                                      affineTidLvls.emplace_back(tl, exp);
+                                    else
+                                      tidLvls.emplace_back(tl);
+                                  });
+}
+
 /// Starts a single loop in current sequence.
 static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
                                               OpBuilder &builder, LoopId curr,
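
For readers skimming the patch, here is a minimal standalone sketch (not part of the patch) of the callback pattern the new getAllTidLvlsInLatPoints relies on: a single traversal reports (TensorLevel, AffineExpr) pairs through an llvm::function_ref, a null AffineExpr marks a plain tensor level, and each caller decides how to collect the pairs. The sketch uses plain std types as stand-ins for the MLIR types, so names like TensorLevelStub and forEachTidLvl are purely illustrative and do not exist in the dialect.

// Standalone sketch of the refactoring idea: one traversal, two collectors.
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

using TensorLevelStub = unsigned;    // stand-in for TensorLevel
using AffineExprStub = const char *; // stand-in for AffineExpr (null == none)

// Single traversal, analogous to getAllTidLvlsInLatPoints: it no longer
// fills concrete vectors itself but reports every pair it finds.
static void forEachTidLvl(
    const std::function<void(TensorLevelStub, AffineExprStub)> &callback) {
  callback(0, nullptr);  // a plain tensor level (no affine address)
  callback(1, "d0 + 2"); // a level whose address comes from an affine expr
  callback(2, nullptr);
}

int main() {
  // Caller 1 (like startLoopSeq): only the levels matter.
  std::vector<TensorLevelStub> tidLvls;
  forEachTidLvl([&](TensorLevelStub tl, AffineExprStub) {
    tidLvls.push_back(tl);
  });

  // Caller 2 (like translateBitsToTidLvlPairs): split by whether an
  // affine expression accompanies the level.
  std::vector<TensorLevelStub> plain;
  std::vector<std::pair<TensorLevelStub, AffineExprStub>> affine;
  forEachTidLvl([&](TensorLevelStub tl, AffineExprStub exp) {
    if (exp)
      affine.emplace_back(tl, exp);
    else
      plain.push_back(tl);
  });

  std::printf("caller1 collected %zu levels; caller2 split %zu/%zu\n",
              tidLvls.size(), plain.size(), affine.size());
  return 0;
}

The design point the sketch illustrates is that the lattice-point walk is written once, and the two former consumers differ only in the small lambda they pass in.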