@@ -148,23 +148,60 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
148
148
// Helper functions that load/store into the position buffer for slice-driven
149
149
// loops.
150
150
// The sliced pointer buffer is orgnized as:
151
- // [size, curPtr] (two metadata) + [[pLo, pHi, pNext], ...] (list of tuples)
151
+ // [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
152
+ // [pHi0, pHi1, pHi2, ...],
153
+ // [pNx0, pNx1, pNx2, ...]]
154
+ static Value allocSlicePosBuf (OpBuilder &builder, Location loc,
155
+ Value tupleCnt) {
156
+ Value bufSz = MULI (tupleCnt, C_IDX (kSliceIterWidth ));
157
+ // Additional two metadata {memSize, idx} at head.
158
+ bufSz = ADDI (bufSz, C_IDX (2 ));
159
+ return genAlloca (builder, loc, bufSz, builder.getIndexType ());
160
+ }
161
+ // TODO: We should use SSA value for it.
162
+ // Gets and sets metadata.
152
163
static Value loadSlicePosPtr (OpBuilder &builder, Location loc, Value sPosBuf ) {
153
- // Load curPtr.
154
- // TODO: We should use SSA value for it.
155
164
return genIndexLoad (builder, loc, sPosBuf , C_IDX (1 ));
156
165
}
157
166
static void updateSlicePosPtr (OpBuilder &builder, Location loc, Value sPosBuf ,
158
167
Value pPtr) {
159
- // Set curPtr.
160
- // TODO: We should use SSA value for it.
161
168
builder.create <memref::StoreOp>(loc, pPtr, sPosBuf , C_IDX (1 ));
162
169
}
163
- static Value loadSliceNextPosPtrStart (OpBuilder &builder, Location loc,
164
- Value sPosBuf , Value tupleIdx) {
165
- // load the pNext in the current tuple specified by `tupleIdx`.
166
- // 4 = 2 (two metadata) + 2 (pNext == tuple[2])
167
- return genIndexLoad (builder, loc, sPosBuf , ADDI (tupleIdx, C_IDX (4 )));
170
+ static Value loadSlicePosTupleNum (OpBuilder &builder, Location loc,
171
+ Value sPosBuf ) {
172
+ return genIndexLoad (builder, loc, sPosBuf , C_IDX (0 ));
173
+ }
174
+ static void updateSlicePosTupleNum (OpBuilder &builder, Location loc, Value num,
175
+ Value sPosBuf ) {
176
+ builder.create <memref::StoreOp>(loc, num, sPosBuf , C_IDX (0 ));
177
+ }
178
+
179
+ // Gets and sets position values for slice-driven loops.
180
+ enum class SlicePosKind { kLo , kHi , kNext };
181
+ static Value getSlicePosIdx (OpBuilder &builder, Location loc, Value posBuf,
182
+ Value tupleIdx, SlicePosKind posKind) {
183
+ Value dim = builder.create <memref::DimOp>(loc, posBuf, C_IDX (0 ));
184
+ Value tupleCnt = DIVUI (SUBI (dim, C_IDX (2 )), C_IDX (kSliceIterWidth ));
185
+ switch (posKind) {
186
+ case SlicePosKind::kLo :
187
+ return ADDI (tupleIdx, C_IDX (2 ));
188
+ case SlicePosKind::kHi :
189
+ return ADDI (tupleIdx, ADDI (tupleCnt, C_IDX (2 )));
190
+ case SlicePosKind::kNext :
191
+ return ADDI (tupleIdx, ADDI (tupleCnt, ADDI (tupleCnt, C_IDX (2 ))));
192
+ }
193
+ llvm_unreachable (" unexpected kind" );
194
+ }
195
+ static Value loadSlicePos (OpBuilder &builder, Location loc, Value sPosBuf ,
196
+ Value tupleIdx, SlicePosKind posKind) {
197
+ return genIndexLoad (builder, loc, sPosBuf ,
198
+ getSlicePosIdx (builder, loc, sPosBuf , tupleIdx, posKind));
199
+ }
200
+ static void updateSlicePos (OpBuilder &builder, Location loc, Value sPosBuf ,
201
+ Value pos, Value tupleIdx, SlicePosKind posKind) {
202
+ builder.create <memref::StoreOp>(
203
+ loc, pos, sPosBuf ,
204
+ getSlicePosIdx (builder, loc, sPosBuf , tupleIdx, posKind));
168
205
}
169
206
170
207
std::pair<Value, Value>
@@ -1445,13 +1482,13 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
1445
1482
// The first compressed level, setting up the position pointer for it.
1446
1483
Value sPosBuf = slicePosBuffer[tid][curLvl].back ();
1447
1484
// One step forwards in the parent level result in forwarding one `segment`
1448
- // (kSliceIterWidth) in the child sparse level.
1449
- Value fPosPtr = MULI (fcnt, C_IDX (kSliceIterWidth )); // forward ptr
1485
+ // in the child sparse level.
1450
1486
Value pPosPtr = loadSlicePosPtr (builder, loc, sPosBuf ); // previous ptr
1451
- Value cPosPtr = ADDI (fPosPtr , pPosPtr); // current ptr
1487
+ Value cPosPtr = ADDI (fcnt , pPosPtr); // current ptr
1452
1488
updateSlicePosPtr (builder, loc, sPosBuf , cPosPtr);
1453
1489
// Loads the position pointer start for next level.
1454
- nxPosPtr = loadSliceNextPosPtrStart (builder, loc, sPosBuf , cPosPtr);
1490
+ nxPosPtr =
1491
+ loadSlicePos (builder, loc, sPosBuf , cPosPtr, SlicePosKind::kNext );
1455
1492
curLvl++;
1456
1493
}
1457
1494
@@ -1463,10 +1500,10 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
1463
1500
for (; curLvl < leafLvl; curLvl++) {
1464
1501
assert (nxPosPtr);
1465
1502
if (!isDenseLT (lvlTypes[tid][curLvl])) {
1466
- nxPosPtr = MULI (nxPosPtr, C_IDX (kSliceIterWidth ));
1467
1503
Value sPosBuf = slicePosBuffer[tid][curLvl].back ();
1468
1504
updateSlicePosPtr (builder, loc, sPosBuf , nxPosPtr);
1469
- nxPosPtr = loadSliceNextPosPtrStart (builder, loc, sPosBuf , nxPosPtr);
1505
+ nxPosPtr =
1506
+ loadSlicePos (builder, loc, sPosBuf , nxPosPtr, SlicePosKind::kNext );
1470
1507
}
1471
1508
}
1472
1509
}
@@ -1736,7 +1773,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
1736
1773
std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
1737
1774
LoopBodyBuilder bodyBuilder) {
1738
1775
1739
- Value c0 = C_IDX (0 ), c1 = C_IDX (1 ), c2 = C_IDX ( 2 ) ;
1776
+ Value c0 = C_IDX (0 ), c1 = C_IDX (1 );
1740
1777
Value pos = c0;
1741
1778
OpBuilder::InsertPoint ip;
1742
1779
SmallVector<Value> innerArgs (userReduc.begin (), userReduc.end ());
@@ -1769,20 +1806,22 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
1769
1806
unsigned depth = frontSlice.depth - 1 ;
1770
1807
Value offset = frontSlice.offset ;
1771
1808
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
1772
- Value mSz = genIndexLoad (builder, loc, sPtrBuf , c0); // memSize
1809
+ Value mSz = loadSlicePosTupleNum (builder, loc, sPtrBuf );
1773
1810
outerMost = builder.create <scf::ForOp>(
1774
- loc, c2 , mSz , C_IDX ( kSliceIterWidth ) , innerArgs,
1775
- [this , c1, c2, tid, firstLvl, offset, sPtrBuf , &ip, &pos,
1811
+ loc, c0 , mSz , c1 , innerArgs,
1812
+ [this , tid, firstLvl, offset, sPtrBuf , &ip, &pos,
1776
1813
&innerArgs](OpBuilder &builder, Location loc, Value iv,
1777
1814
ValueRange iterArgs) {
1778
1815
// generate traversal for each level.
1779
- Value loopLo = genIndexLoad (builder, loc, sPtrBuf , iv);
1780
- Value loopHi = genIndexLoad (builder, loc, sPtrBuf , ADDI (iv, c1));
1816
+ Value loopLo =
1817
+ loadSlicePos (builder, loc, sPtrBuf , iv, SlicePosKind::kLo );
1818
+ Value loopHi =
1819
+ loadSlicePos (builder, loc, sPtrBuf , iv, SlicePosKind::kHi );
1781
1820
// We need to remember the starting index for next level's
1782
1821
// position, because slice-driven loop breaks the level into
1783
1822
// non-consecutive segments.
1784
- builder. create <memref::StoreOp>( loc, iterArgs.back (), sPtrBuf ,
1785
- ADDI (iv, c2). getResult () );
1823
+ updateSlicePos (builder, loc, sPtrBuf , iterArgs.back (), iv ,
1824
+ SlicePosKind:: kNext );
1786
1825
1787
1826
auto [size, stride] = sliceMeta[tid][firstLvl].back ();
1788
1827
assert (stride == 1 && " Not yet implemented" );
@@ -1873,8 +1912,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
1873
1912
1874
1913
void LoopEmitter::genResolvedSliceBegin (OpBuilder &builder, Location loc,
1875
1914
TensorId tid, Level lvl) {
1876
- Value c0 = C_IDX (0 ), c1 = C_IDX (1 ), c2 = C_IDX (2 ), c3 = C_IDX (3 ),
1877
- c4 = C_IDX (4 );
1915
+ Value c0 = C_IDX (0 ), c1 = C_IDX (1 );
1878
1916
if (isDenseLT (lvlTypes[tid][lvl])) {
1879
1917
// Dense slice begin is trivial.
1880
1918
sliceStack[tid].emplace_back (/* minCoord=*/ c0, /* offset=*/ c0,
@@ -1896,10 +1934,10 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
1896
1934
ADDI (posits[tid][lvl - 1 ], c1));
1897
1935
}
1898
1936
// Fills out pIdxBuffer[tid][lvl][0] with [/*memSize =*/4, 0, pLo, pHi]
1899
- builder. create <memref::StoreOp>(loc, c4, sPtrBuf , c0); // memSize = 4
1900
- builder. create <memref::StoreOp>(loc, c0 , sPtrBuf , c1); // index = 0
1901
- builder. create <memref::StoreOp>( loc, pLo, sPtrBuf , c2); // pLo
1902
- builder. create <memref::StoreOp>( loc, pHi, sPtrBuf , c3); // pHi
1937
+ updateSlicePosTupleNum (builder, loc, c1, sPtrBuf );
1938
+ updateSlicePosPtr (builder, loc , sPtrBuf , c0);
1939
+ updateSlicePos (builder, loc, sPtrBuf , pLo, c0, SlicePosKind:: kLo );
1940
+ updateSlicePos (builder, loc, sPtrBuf , pHi, c0, SlicePosKind:: kHi );
1903
1941
1904
1942
// This is an non empty tensor if pLo < pHi.
1905
1943
Value isNonEmpty = CMPI (ult, pLo, pHi);
@@ -1938,7 +1976,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
1938
1976
// }
1939
1977
void LoopEmitter::genUnResolvedSliceBegin (OpBuilder &builder, Location loc,
1940
1978
TensorId tid, Level lvl) {
1941
- Value c0 = C_IDX (0 ), c1 = C_IDX (1 ), c2 = C_IDX ( 2 ) ;
1979
+ Value c0 = C_IDX (0 ), c1 = C_IDX (1 );
1942
1980
unsigned depth = levelReducedDep[tid][lvl];
1943
1981
// The remaining slice size after reduction.
1944
1982
Value remSz = sliceMeta[tid][lvl][depth + 1 ].first ;
@@ -1983,7 +2021,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
1983
2021
SmallVector<Value, 3 > reduc = {
1984
2022
constantI1 (builder, loc, false ), // isNonEmpty
1985
2023
lvlSizes[tid][lvl], // minCoord
1986
- c2 , // memSize
2024
+ c0 , // memSize
1987
2025
};
1988
2026
1989
2027
ValueRange result = genUnResolvedSliceTreeTraverse (
@@ -1992,7 +2030,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
1992
2030
MutableArrayRef<Value> reduc) {
1993
2031
Value &nonEmpty = reduc[0 ];
1994
2032
Value &minCrd = reduc[1 ];
1995
- Value &curMemSz = reduc[2 ];
2033
+ Value &curTupleCnt = reduc[2 ];
1996
2034
1997
2035
Value pHi = ADDI (iv, c1);
1998
2036
Value sPLo = genIndexLoad (builder, loc, positionsBuffers[tid][lvl], iv);
@@ -2024,28 +2062,26 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
2024
2062
YIELD (minCrd);
2025
2063
}
2026
2064
minCrd = ifNonEmpty.getResult (0 );
2027
- builder. create <memref::StoreOp>(loc, sPLo , sPtrBuf , curMemSz);
2028
- Value nxtMemSize = ADDI (curMemSz, c1 );
2029
- builder. create <memref::StoreOp>(loc, sPHi , sPtrBuf , nxtMemSize);
2030
- // curMemSize += kSliceIterWidth
2031
- curMemSz = ADDI (curMemSz , C_IDX (kSliceIterWidth ));
2065
+ updateSlicePos (builder, loc , sPtrBuf , sPLo , curTupleCnt,
2066
+ SlicePosKind:: kLo );
2067
+ updateSlicePos (builder, loc , sPtrBuf , sPHi , curTupleCnt,
2068
+ SlicePosKind:: kHi );
2069
+ curTupleCnt = ADDI (curTupleCnt , C_IDX (1 ));
2032
2070
});
2033
2071
2034
2072
Value isNonEmpty = result[0 ];
2035
2073
Value minCrd = result[1 ];
2036
2074
// Two metadata [memSize, idx].
2037
2075
// TODO: Can use an SSA value for these two metadata
2038
- builder. create <memref::StoreOp>( loc, result[2 ], sPtrBuf , c0 );
2039
- builder. create <memref::StoreOp>(loc, c0 , sPtrBuf , c1 );
2076
+ updateSlicePosTupleNum (builder, loc, result[2 ], sPtrBuf );
2077
+ updateSlicePosPtr (builder, loc , sPtrBuf , c0 );
2040
2078
// FIXME: we need the relative offset related to the base slice.
2041
2079
Value absOffset = offsetFromMinCoord (builder, loc, minCrd, remSz, isNonEmpty);
2042
2080
sliceStack[tid].emplace_back (minCrd, absOffset, isNonEmpty, lvl, depth + 1 );
2043
2081
}
2044
2082
2045
2083
bool LoopEmitter::genSliceBegin (OpBuilder &builder, Location loc, TensorId tid,
2046
2084
Level lvl) {
2047
- Value c1 = C_IDX (1 ), c2 = C_IDX (2 );
2048
-
2049
2085
if (depFullyReduced (tid, lvl)) {
2050
2086
// Do not need to prepare for slice driven loop on dense level after it is
2051
2087
// fully reduced.
@@ -2054,14 +2090,12 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
2054
2090
// If constraints on the tensor is fully resolved. We do not need to
2055
2091
// generates slice begin any more, instead we fall back to TACO-based
2056
2092
// algorithm to (co)iterates over the slice.
2057
- Value pLoPtr =
2058
- loadSlicePosPtr (builder, loc, slicePosBuffer[tid][lvl].back ());
2059
- pLoPtr = ADDI (pLoPtr, c2);
2060
- Value pHiPtr = ADDI (pLoPtr, c1);
2093
+ Value sPosBuf = slicePosBuffer[tid][lvl].back ();
2094
+ Value tupleIdx = loadSlicePosPtr (builder, loc, sPosBuf );
2061
2095
posits[tid][lvl] =
2062
- genIndexLoad (builder, loc, slicePosBuffer[tid][lvl]. back (), pLoPtr );
2096
+ loadSlicePos (builder, loc, sPosBuf , tupleIdx, SlicePosKind:: kLo );
2063
2097
highs[tid][lvl] =
2064
- genIndexLoad (builder, loc, slicePosBuffer[tid][lvl]. back (), pHiPtr );
2098
+ loadSlicePos (builder, loc, sPosBuf , tupleIdx, SlicePosKind:: kHi );
2065
2099
return true ;
2066
2100
}
2067
2101
@@ -2090,8 +2124,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
2090
2124
// The buffer can be reused, and the size is loop invariant: it only
2091
2125
// depends on the iteration graph's toposort.
2092
2126
builder.setInsertionPointAfter (localInsertPos);
2093
- Value bufSize = C_IDX (1 );
2094
- Value c2 = C_IDX (2 );
2127
+ Value tupleCnt = C_IDX (1 );
2095
2128
// Accumlates the size required to cache the pLo for the slice.
2096
2129
// E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
2097
2130
// level. We at most need to a memref<d0xindex>.
@@ -2108,16 +2141,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
2108
2141
assert (!sliceMeta[tid][curLevel - 1 ].empty ());
2109
2142
auto [sz, stride] = sliceMeta[tid][curLevel - 1 ].back ();
2110
2143
assert (stride == 1 && " Not yet implemented" );
2111
- bufSize = MULI (bufSize , sz);
2144
+ tupleCnt = MULI (tupleCnt , sz);
2112
2145
}
2113
- // For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
2114
- // because slice creates segments in the index buffer so that the pHi for
2115
- // the current level is no longer the pLo for the next level.
2116
- bufSize = MULI (bufSize, C_IDX (kSliceIterWidth ));
2117
- // Additional two metadata {memSize, idx} at head.
2118
- bufSize = ADDI (bufSize, c2);
2119
2146
for (Value &cache : slicePosBuffer[tid][lvl])
2120
- cache = genAlloca (builder, loc, bufSize, builder. getIndexType () );
2147
+ cache = allocSlicePosBuf (builder, loc, tupleCnt );
2121
2148
}
2122
2149
2123
2150
if (sliceInfo.isInitialTensor () ||
@@ -2147,7 +2174,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
2147
2174
llvm_unreachable (" TODO" );
2148
2175
2149
2176
// else generate code to compute next non empty slice.
2150
- Value c0 = C_IDX (0 ), c1 = C_IDX (1 ), c2 = C_IDX ( 2 ) ;
2177
+ Value c0 = C_IDX (0 ), c1 = C_IDX (1 );
2151
2178
2152
2179
SliceInfo &info = sliceStack[tid].back ();
2153
2180
assert (info.slicedOnLvl == lvl);
@@ -2194,24 +2221,24 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
2194
2221
// offset = minCrd - size + 1;
2195
2222
// }
2196
2223
builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
2197
- reduc[2 ] = absOffset; // restore value.
2198
- Value pSt = c2; // pointer starting index
2199
- Value mSz = genIndexLoad (builder, loc, sPtrBuf , c0); // memSize
2200
- reduc[0 ] = lvlSizes[tid][lvl]; // next min coord
2201
- reduc[1 ] = constantI1 (builder, loc, false ); // isNonEmpty
2224
+ reduc[2 ] = absOffset; // restore value.
2225
+ Value mSz = loadSlicePosTupleNum (builder, loc, sPtrBuf ); // memSize
2226
+ reduc[0 ] = lvlSizes[tid][lvl]; // next min coord
2227
+ reduc[1 ] = constantI1 (builder, loc, false ); // isNonEmpty
2202
2228
auto loopArgs = static_cast <ValueRange>(reduc).drop_back ();
2203
2229
auto forOp = scf::buildLoopNest (
2204
- builder, loc, pSt , mSz , C_IDX ( kSliceIterWidth ) , loopArgs,
2230
+ builder, loc, c0 , mSz , c1 , loopArgs,
2205
2231
[this , tid, lvl, c1, sPtrBuf ,
2206
2232
&info](OpBuilder &builder, Location loc, ValueRange ivs,
2207
2233
ValueRange iterArgs) -> scf::ValueVector {
2208
2234
Value curMinCrd = iterArgs[0 ];
2209
2235
Value isNonEmpty = iterArgs[1 ];
2210
2236
2211
2237
Type idxTp = builder.getIndexType ();
2212
- Value pLo = genIndexLoad (builder, loc, sPtrBuf , ivs.front ());
2213
- Value pHi =
2214
- genIndexLoad (builder, loc, sPtrBuf , ADDI (ivs.front (), c1));
2238
+ Value pLo = loadSlicePos (builder, loc, sPtrBuf , ivs.front (),
2239
+ SlicePosKind::kLo );
2240
+ Value pHi = loadSlicePos (builder, loc, sPtrBuf , ivs.front (),
2241
+ SlicePosKind::kHi );
2215
2242
//
2216
2243
// if (pLo < pHi) // Only loads when inbound.
2217
2244
// coord = load[pLo]
@@ -2235,8 +2262,8 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
2235
2262
&ifEqual.getThenRegion ().front ());
2236
2263
Value newPlo = ADDI (pLo, c1);
2237
2264
// Updates the cache.
2238
- builder. create <memref::StoreOp>( loc, newPlo, sPtrBuf ,
2239
- ivs. front () );
2265
+ updateSlicePos (builder, loc, sPtrBuf , newPlo, ivs. front () ,
2266
+ SlicePosKind:: kLo );
2240
2267
YIELD (newPlo);
2241
2268
}
2242
2269
/* else coord != minCrd */ {
0 commit comments