@@ -147,40 +147,30 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
 
 // Helper functions that load/store into the position buffer for slice-driven
 // loops.
-// The sliced pointer buffer is orgnized as:
-//     [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
-//                                      [pHi0, pHi1, pHi2, ...],
-//                                      [pNx0, pNx1, pNx2, ...]]
+// The sliced pointer buffer is organized as:
+//     [[pLo0, pLo1, pLo2, ...],
+//      [pHi0, pHi1, pHi2, ...],
+//      [pNx0, pNx1, pNx2, ...]]
 static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
                               Value tupleCnt) {
   Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
   // Additional two metadata {memSize, idx} at head.
-  bufSz = ADDI(bufSz, C_IDX(2));
   return genAlloca(builder, loc, bufSz, builder.getIndexType());
 }
-// TODO: We should use SSA value for it.
-// Gets and sets metadata.
-static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
-  return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
-}
-static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
-                              Value pPtr) {
-  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
-}
 
 // Gets and sets position values for slice-driven loops.
 enum class SlicePosKind { kLo, kHi, kNext };
 static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
                             Value tupleIdx, SlicePosKind posKind) {
   Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
-  Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
+  Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
   switch (posKind) {
   case SlicePosKind::kLo:
-    return ADDI(tupleIdx, C_IDX(2));
+    return tupleIdx;
   case SlicePosKind::kHi:
-    return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
+    return ADDI(tupleIdx, tupleCnt);
   case SlicePosKind::kNext:
-    return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
+    return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
   }
   llvm_unreachable("unexpected kind");
 }
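With the two-slot metadata header gone, getSlicePosIdx above reduces to plain row-major indexing into the three rows (pLo, pHi, pNx) of the buffer. A minimal standalone sketch of that arithmetic, using illustrative names rather than the emitter's MLIR builder macros:

#include <cassert>
#include <cstddef>

// Illustrative stand-in for the index math after this change: the buffer is
// kSliceIterWidth (three) rows of tupleCnt entries each, laid out row-major
// as [pLo..., pHi..., pNx...] with no metadata header.
enum class SlicePosKind { kLo, kHi, kNext };

static std::size_t slicePosIdx(std::size_t tupleIdx, std::size_t tupleCnt,
                               SlicePosKind posKind) {
  switch (posKind) {
  case SlicePosKind::kLo:
    return tupleIdx;                // row 0
  case SlicePosKind::kHi:
    return tupleIdx + tupleCnt;     // row 1
  case SlicePosKind::kNext:
    return tupleIdx + 2 * tupleCnt; // row 2
  }
  return 0; // unreachable
}

int main() {
  // With 4 tuples, tuple #1 reads its pLo/pHi/pNx at offsets 1, 5 and 9;
  // the old layout added 2 to each of these for the metadata header.
  assert(slicePosIdx(1, 4, SlicePosKind::kLo) == 1);
  assert(slicePosIdx(1, 4, SlicePosKind::kHi) == 5);
  assert(slicePosIdx(1, 4, SlicePosKind::kNext) == 9);
  return 0;
}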
@@ -344,6 +334,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->dependentLvlMap.assign(
       numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
   this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
+  this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
+  this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
+  this->trivialSlice.assign(numTensors, std::vector<bool>());
   this->sliceMeta.assign(
       numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
   this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
@@ -394,6 +387,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     dependentLvlMap[tid].assign(
         lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
     slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
+    sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
+    sliceTupleFwdCnt[tid].assign(lvlRank, Value());
+    trivialSlice[tid].assign(lvlRank, false);
     sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
     sliceStack[tid].emplace_back(/*minCrd=*/Value(),
                                  /*offset=*/Value(), /*isNonEmpty*/ Value(),
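The three vectors added in initialize() above carry, as SSA values, the bookkeeping that previously lived in the removed {memSize, idx} header (see the deleted loadSlicePosPtr/updateSlicePosPtr helpers). A hypothetical grouping of the per-(tensor, level) state, with comments reflecting how the rest of this diff uses it; the emitter itself keeps parallel vectors indexed by [tid][lvl]:

#include "mlir/IR/Value.h"

// Hypothetical per-(tensor, level) grouping of the state added above.
struct SliceLevelState {
  // Where this level's tuples start in its slice position buffer; set in
  // genSliceBegin, either to 0 or by chasing the parent's pNx entry.
  mlir::Value sliceTupleNxStartIdx;
  // How far the iterator has advanced at this level; set in genWhileLoopBody
  // for sparse levels and in enterTensorsAtDenseLvls for dense ones.
  mlir::Value sliceTupleFwdCnt;
  // True when the slice was generated over a fully resolved parent
  // (genResolvedSliceBegin), false for the unresolved case.
  bool trivialSlice;
};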
@@ -806,6 +802,7 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
     assert(ivs.size() == 1);
     // Coord is the relative offset related to its parents.
     assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
+    sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
     // Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
     Value posit = ivs[0];
     Value crdBuf = coordinatesBuffers[tid][lvl];
@@ -1324,6 +1321,12 @@ void LoopEmitter::enterTensorsAtDenseLvls(
       } else {
         posits[tid][lvl] =
             genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
+        Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
+                           ? C_IDX(0)
+                           : sliceTupleFwdCnt[tid][lvl - 1];
+        Value sz = sliceMeta[tid][lvl].back().first;
+        Value mul = MULI(fwdCnt, sz);
+        sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
       }
       levelReducedDep[tid][lvl]++;
     } else {
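Together with the sparse-level update in genWhileLoopBody above (forward count = current position minus the level's start position), the dense-level code here flattens the count: the parent's forward count is scaled by this level's slice size and the current induction value is added. A small arithmetic sketch of that recurrence; the concrete numbers are made up for illustration:

#include <cassert>
#include <cstdint>

// Illustrative recurrence for the dense-level forward count:
//   fwdCnt[lvl] = (lvl == 0 || trivialSlice ? 0 : fwdCnt[lvl - 1]) * sz + iv
// where sz is sliceMeta[tid][lvl].back().first and iv the loop induction value.
static std::uint64_t denseFwdCnt(std::uint64_t parentCnt, bool trivialSlice,
                                 bool isFirstLvl, std::uint64_t sz,
                                 std::uint64_t iv) {
  std::uint64_t base = (isFirstLvl || trivialSlice) ? 0 : parentCnt;
  return base * sz + iv;
}

int main() {
  // Parent advanced 2 segments, this dense level has slice size 4 and the
  // loop is at iv = 3: the flattened forward count becomes 2 * 4 + 3 = 11.
  assert(denseFwdCnt(/*parentCnt=*/2, /*trivialSlice=*/false,
                     /*isFirstLvl=*/false, /*sz=*/4, /*iv=*/3) == 11);
  return 0;
}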
@@ -1357,13 +1360,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
       assert(isDenseLT(lvlTypes[tid][lvl]));
       assert(*info.slicedOnLvl == lvl);
       (void)reduced;
-      // Resets slices pointers as the resolved slices are invalidated after we
-      // moves forward to the next slice.
-      invalidateSliceIterIdx(rewriter, loc, tid, lvl);
       info.minCrd = info.offset = info.isNonEmpty = Value();
-    } else {
-      forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
-                                      constantIndex(rewriter, loc, 1));
     }
     levelReducedDep[tid][lvl]--;
   }
@@ -1443,54 +1440,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
-void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
-                                                  Location loc, TensorId tid,
-                                                  Level rootLvl, Value fcnt) {
-
-  auto stt = getSparseTensorType(tensors[tid]);
-
-  // Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
-  // sparse levels (but not resolved). Since we forward an iterator at higher
-  // level of the tree, the subtree need to be pruned.
-  Level leafLvl = rootLvl + 1;
-  while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
-         !stt.isDenseLvl(leafLvl)) {
-    leafLvl++;
-  }
-
-  Level curLvl = rootLvl + 1;
-  Value nxPosPtr = nullptr;
-  if (curLvl < leafLvl) {
-    assert(!isDenseLT(lvlTypes[tid][curLvl]));
-    // The first compressed level, setting up the position pointer for it.
-    Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-    // One step forwards in the parent level result in forwarding one `segment`
-    // in the child sparse level.
-    Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
-    Value cPosPtr = ADDI(fcnt, pPosPtr);                    // current ptr
-    updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
-    // Loads the position pointer start for next level.
-    nxPosPtr =
-        loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
-    curLvl++;
-  }
-
-  // TODO: This is not always needed, but we did it unconditionally for now for
-  // simplicity.
-  // It is only needed when `curLvl` is forwarded without traversing its child
-  // level (e.g., the level is in a conjunctive lattices and got pruned), such
-  // that the position pointer is not forwarded inside the loop.
-  for (; curLvl < leafLvl; curLvl++) {
-    assert(nxPosPtr);
-    if (!isDenseLT(lvlTypes[tid][curLvl])) {
-      Value sPosBuf = slicePosBuffer[tid][curLvl].back();
-      updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
-      nxPosPtr =
-          loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
-    }
-  }
-}
-
 void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
                                 MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
@@ -1540,13 +1489,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
       forwarded = CMPI(eq, coords[tid][lvl], iv);
       operands.push_back(SELECT(forwarded, nxPos, pos));
     }
-    {
-      OpBuilder::InsertionGuard guard(builder);
-      auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
-                                            /*else=*/false);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
-    }
     // The coordinate is invalid now.
     coords[tid][lvl] = nullptr;
 
@@ -1916,8 +1858,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
     pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
                        ADDI(posits[tid][lvl - 1], c1));
   }
-  // Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
+  // Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
   updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
   updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
   // Slice over a resolved parent, we only need one pair of pos hi and lo to
@@ -2056,8 +1997,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
   Value isNonEmpty = result[0];
   Value minCrd = result[1];
   // Two metadata [memSize, idx].
-  // TODO: Can use an SSA value for these two metadata
-  updateSlicePosPtr(builder, loc, sPtrBuf, c0);
   // FIXME: we need the relative offset related to the base slice.
   Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
   sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
@@ -2066,16 +2005,30 @@
 
 bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
                                 Level lvl) {
+  Value curLvlIdx = C_IDX(0);
   if (depFullyReduced(tid, lvl)) {
-    // Do not need to prepare for slice driven loop on dense level after it is
-    // fully reduced.
+    if (lvl == 0 || trivialSlice[tid][lvl]) {
+      sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
+    } else {
+      if (isDenseLT(lvlTypes[tid][lvl])) {
+        sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
+      } else {
+        assert(isCompressedLT(lvlTypes[tid][lvl]));
+        curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
+                         sliceTupleFwdCnt[0][lvl - 1]);
+        sliceTupleNxStartIdx[tid][lvl] =
+            loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
+                         curLvlIdx, SlicePosKind::kNext);
+      }
+    }
     if (isDenseLT(lvlTypes[tid][lvl]))
       return true;
+
+    Value sPosBuf = slicePosBuffer[tid][lvl].back();
     // If constraints on the tensor is fully resolved. We do not need to
     // generates slice begin any more, instead we fall back to TACO-based
     // algorithm to (co)iterates over the slice.
-    Value sPosBuf = slicePosBuffer[tid][lvl].back();
-    Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
+    Value tupleIdx = curLvlIdx;
     posits[tid][lvl] =
         loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
     highs[tid][lvl] =
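The new genSliceBegin path derives the tuple index for a fully reduced compressed level purely from SSA values: the start index inherited from the parent plus the parent's forward count, with the level's own pNx entry then giving the start for the next level. A toy model of that lookup in plain C++ with made-up buffer contents, not the emitter's API:

#include <cassert>
#include <cstddef>
#include <vector>

// Toy model of the SSA-based tuple lookup for a fully reduced compressed
// level. posBuf is that level's slice position buffer: three rows
// (pLo, pHi, pNx) of tupleCnt entries, as in the layout described above.
struct SliceTuple {
  std::size_t pLo, pHi, pNx;
};

static SliceTuple loadSliceTuple(const std::vector<std::size_t> &posBuf,
                                 std::size_t tupleCnt,
                                 std::size_t parentNxStart,
                                 std::size_t parentFwdCnt) {
  // curLvlIdx = sliceTupleNxStartIdx[lvl - 1] + sliceTupleFwdCnt[lvl - 1],
  // replacing the mutable curPtr slot the old code kept in the buffer header.
  std::size_t tupleIdx = parentNxStart + parentFwdCnt;
  return {posBuf[tupleIdx],                 // posits[tid][lvl]
          posBuf[tupleCnt + tupleIdx],      // highs[tid][lvl]
          posBuf[2 * tupleCnt + tupleIdx]}; // next level's start index
}

int main() {
  // Two tuples: pLo = {0, 5}, pHi = {5, 9}, pNx = {0, 3}.
  std::vector<std::size_t> posBuf = {0, 5, 5, 9, 0, 3};
  // The parent block starts at tuple 0 and the parent has advanced one
  // segment, so this level reads tuple 1: positions [5, 9), child start 3.
  SliceTuple t = loadSliceTuple(posBuf, /*tupleCnt=*/2, /*parentNxStart=*/0,
                                /*parentFwdCnt=*/1);
  assert(t.pLo == 5 && t.pHi == 9 && t.pNx == 3);
  return 0;
}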
@@ -2134,23 +2087,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
   if (sliceInfo.isInitialTensor() ||
       (lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
     // First level or previous level has been full resolved.
+    trivialSlice[tid][lvl] = true;
     genResolvedSliceBegin(builder, loc, tid, lvl);
   } else {
     // The previous level has not been full resolved.
+    trivialSlice[tid][lvl] = false;
     genUnResolvedSliceBegin(builder, loc, tid, lvl);
   }
   return false;
 }
 
-void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
-                                         TensorId tid, Level lvl) {
-  for (unsigned i = 0; i <= lvl; i++) {
-    if (!isDenseLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
-      updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
-    }
-  }
-}
-
 std::tuple<Value, Value, Value>
 LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
                                    TensorId tid, Level lvl) {
@@ -2175,10 +2121,6 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
   //    isNonEmpty = false;
   //
   Value absOffset = info.offset;
-  // Resets slices pointers as the resolved slices are invalidated after we
-  // moves forward to the next slice.
-  invalidateSliceIterIdx(builder, loc, tid, lvl);
-
   SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
   Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
   Value fastPathP = CMPI(ugt, info.minCrd, absOffset);