Skip to content

Commit b6cad75

Browse files
authored
[mlir][sparse] refactoring: using util functions to query the index to load from position array for slice-driven loop. (#73986)
1 parent 7ae1b76 commit b6cad75

File tree

3 files changed

+308
-274
lines changed

3 files changed

+308
-274
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 97 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,60 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
148148
// Helper functions that load/store into the position buffer for slice-driven
149149
// loops.
150150
// The sliced pointer buffer is organized as:
151-
// [size, curPtr] (two metadata) + [[pLo, pHi, pNext], ...] (list of tuples)
151+
// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
152+
// [pHi0, pHi1, pHi2, ...],
153+
// [pNx0, pNx1, pNx2, ...]]
154+
static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
155+
Value tupleCnt) {
156+
Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
157+
// Additional two metadata {memSize, idx} at head.
158+
bufSz = ADDI(bufSz, C_IDX(2));
159+
return genAlloca(builder, loc, bufSz, builder.getIndexType());
160+
}
161+
// TODO: We should use SSA value for it.
162+
// Gets and sets metadata.
152163
static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
153-
// Load curPtr.
154-
// TODO: We should use SSA value for it.
155164
return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
156165
}
157166
static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
158167
Value pPtr) {
159-
// Set curPtr.
160-
// TODO: We should use SSA value for it.
161168
builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
162169
}
163-
static Value loadSliceNextPosPtrStart(OpBuilder &builder, Location loc,
164-
Value sPosBuf, Value tupleIdx) {
165-
// load the pNext in the current tuple specified by `tupleIdx`.
166-
// 4 = 2 (two metadata) + 2 (pNext == tuple[2])
167-
return genIndexLoad(builder, loc, sPosBuf, ADDI(tupleIdx, C_IDX(4)));
170+
static Value loadSlicePosTupleNum(OpBuilder &builder, Location loc,
171+
Value sPosBuf) {
172+
return genIndexLoad(builder, loc, sPosBuf, C_IDX(0));
173+
}
174+
static void updateSlicePosTupleNum(OpBuilder &builder, Location loc, Value num,
175+
Value sPosBuf) {
176+
builder.create<memref::StoreOp>(loc, num, sPosBuf, C_IDX(0));
177+
}
178+
179+
// Gets and sets position values for slice-driven loops.
180+
enum class SlicePosKind { kLo, kHi, kNext };
181+
static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
182+
Value tupleIdx, SlicePosKind posKind) {
183+
Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
184+
Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
185+
switch (posKind) {
186+
case SlicePosKind::kLo:
187+
return ADDI(tupleIdx, C_IDX(2));
188+
case SlicePosKind::kHi:
189+
return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
190+
case SlicePosKind::kNext:
191+
return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
192+
}
193+
llvm_unreachable("unexpected kind");
194+
}
195+
static Value loadSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
196+
Value tupleIdx, SlicePosKind posKind) {
197+
return genIndexLoad(builder, loc, sPosBuf,
198+
getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind));
199+
}
200+
static void updateSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
201+
Value pos, Value tupleIdx, SlicePosKind posKind) {
202+
builder.create<memref::StoreOp>(
203+
loc, pos, sPosBuf,
204+
getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind));
168205
}
169206

170207
std::pair<Value, Value>
@@ -1445,13 +1482,13 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
14451482
// The first compressed level, setting up the position pointer for it.
14461483
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
14471484
// One step forwards in the parent level results in forwarding one `segment`
1448-
// (kSliceIterWidth) in the child sparse level.
1449-
Value fPosPtr = MULI(fcnt, C_IDX(kSliceIterWidth)); // forward ptr
1485+
// in the child sparse level.
14501486
Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
1451-
Value cPosPtr = ADDI(fPosPtr, pPosPtr); // current ptr
1487+
Value cPosPtr = ADDI(fcnt, pPosPtr); // current ptr
14521488
updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
14531489
// Loads the position pointer start for next level.
1454-
nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, cPosPtr);
1490+
nxPosPtr =
1491+
loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
14551492
curLvl++;
14561493
}
14571494

@@ -1463,10 +1500,10 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
14631500
for (; curLvl < leafLvl; curLvl++) {
14641501
assert(nxPosPtr);
14651502
if (!isDenseLT(lvlTypes[tid][curLvl])) {
1466-
nxPosPtr = MULI(nxPosPtr, C_IDX(kSliceIterWidth));
14671503
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
14681504
updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
1469-
nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, nxPosPtr);
1505+
nxPosPtr =
1506+
loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
14701507
}
14711508
}
14721509
}
@@ -1736,7 +1773,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
17361773
std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
17371774
LoopBodyBuilder bodyBuilder) {
17381775

1739-
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
1776+
Value c0 = C_IDX(0), c1 = C_IDX(1);
17401777
Value pos = c0;
17411778
OpBuilder::InsertPoint ip;
17421779
SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
@@ -1769,20 +1806,22 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
17691806
unsigned depth = frontSlice.depth - 1;
17701807
Value offset = frontSlice.offset;
17711808
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
1772-
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
1809+
Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf);
17731810
outerMost = builder.create<scf::ForOp>(
1774-
loc, c2, mSz, C_IDX(kSliceIterWidth), innerArgs,
1775-
[this, c1, c2, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
1811+
loc, c0, mSz, c1, innerArgs,
1812+
[this, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
17761813
&innerArgs](OpBuilder &builder, Location loc, Value iv,
17771814
ValueRange iterArgs) {
17781815
// generate traversal for each level.
1779-
Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
1780-
Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
1816+
Value loopLo =
1817+
loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kLo);
1818+
Value loopHi =
1819+
loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kHi);
17811820
// We need to remember the starting index for next level's
17821821
// position, because slice-driven loop breaks the level into
17831822
// non-consecutive segments.
1784-
builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
1785-
ADDI(iv, c2).getResult());
1823+
updateSlicePos(builder, loc, sPtrBuf, iterArgs.back(), iv,
1824+
SlicePosKind::kNext);
17861825

17871826
auto [size, stride] = sliceMeta[tid][firstLvl].back();
17881827
assert(stride == 1 && "Not yet implemented");
@@ -1873,8 +1912,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
18731912

18741913
void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18751914
TensorId tid, Level lvl) {
1876-
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
1877-
c4 = C_IDX(4);
1915+
Value c0 = C_IDX(0), c1 = C_IDX(1);
18781916
if (isDenseLT(lvlTypes[tid][lvl])) {
18791917
// Dense slice begin is trivial.
18801918
sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0,
@@ -1896,10 +1934,10 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18961934
ADDI(posits[tid][lvl - 1], c1));
18971935
}
18981936
// Fills out pIdxBuffer[tid][lvl][0] with [/*tupleCnt =*/1, /*curPtr =*/0, pLo, pHi]
1899-
builder.create<memref::StoreOp>(loc, c4, sPtrBuf, c0); // memSize = 4
1900-
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1); // index = 0
1901-
builder.create<memref::StoreOp>(loc, pLo, sPtrBuf, c2); // pLo
1902-
builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // pHi
1937+
updateSlicePosTupleNum(builder, loc, c1, sPtrBuf);
1938+
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
1939+
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
1940+
updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
19031941

19041942
// This is a non-empty tensor if pLo < pHi.
19051943
Value isNonEmpty = CMPI(ult, pLo, pHi);
@@ -1938,7 +1976,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
19381976
// }
19391977
void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19401978
TensorId tid, Level lvl) {
1941-
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
1979+
Value c0 = C_IDX(0), c1 = C_IDX(1);
19421980
unsigned depth = levelReducedDep[tid][lvl];
19431981
// The remaining slice size after reduction.
19441982
Value remSz = sliceMeta[tid][lvl][depth + 1].first;
@@ -1983,7 +2021,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19832021
SmallVector<Value, 3> reduc = {
19842022
constantI1(builder, loc, false), // isNonEmpty
19852023
lvlSizes[tid][lvl], // minCoord
1986-
c2, // memSize
2024+
c0, // memSize
19872025
};
19882026

19892027
ValueRange result = genUnResolvedSliceTreeTraverse(
@@ -1992,7 +2030,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19922030
MutableArrayRef<Value> reduc) {
19932031
Value &nonEmpty = reduc[0];
19942032
Value &minCrd = reduc[1];
1995-
Value &curMemSz = reduc[2];
2033+
Value &curTupleCnt = reduc[2];
19962034

19972035
Value pHi = ADDI(iv, c1);
19982036
Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
@@ -2024,28 +2062,26 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
20242062
YIELD(minCrd);
20252063
}
20262064
minCrd = ifNonEmpty.getResult(0);
2027-
builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
2028-
Value nxtMemSize = ADDI(curMemSz, c1);
2029-
builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
2030-
// curMemSize += kSliceIterWidth
2031-
curMemSz = ADDI(curMemSz, C_IDX(kSliceIterWidth));
2065+
updateSlicePos(builder, loc, sPtrBuf, sPLo, curTupleCnt,
2066+
SlicePosKind::kLo);
2067+
updateSlicePos(builder, loc, sPtrBuf, sPHi, curTupleCnt,
2068+
SlicePosKind::kHi);
2069+
curTupleCnt = ADDI(curTupleCnt, C_IDX(1));
20322070
});
20332071

20342072
Value isNonEmpty = result[0];
20352073
Value minCrd = result[1];
20362074
// Two metadata [memSize, idx].
20372075
// TODO: Can use an SSA value for these two metadata
2038-
builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
2039-
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
2076+
updateSlicePosTupleNum(builder, loc, result[2], sPtrBuf);
2077+
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
20402078
// FIXME: we need the relative offset related to the base slice.
20412079
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
20422080
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
20432081
}
20442082

20452083
bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
20462084
Level lvl) {
2047-
Value c1 = C_IDX(1), c2 = C_IDX(2);
2048-
20492085
if (depFullyReduced(tid, lvl)) {
20502086
// Do not need to prepare for slice driven loop on dense level after it is
20512087
// fully reduced.
@@ -2054,14 +2090,12 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
20542090
// If constraints on the tensor are fully resolved, we do not need to
20552091
// generate slice begin any more; instead we fall back to the TACO-based
20562092
// algorithm to (co)iterate over the slice.
2057-
Value pLoPtr =
2058-
loadSlicePosPtr(builder, loc, slicePosBuffer[tid][lvl].back());
2059-
pLoPtr = ADDI(pLoPtr, c2);
2060-
Value pHiPtr = ADDI(pLoPtr, c1);
2093+
Value sPosBuf = slicePosBuffer[tid][lvl].back();
2094+
Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
20612095
posits[tid][lvl] =
2062-
genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pLoPtr);
2096+
loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
20632097
highs[tid][lvl] =
2064-
genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pHiPtr);
2098+
loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kHi);
20652099
return true;
20662100
}
20672101

@@ -2090,8 +2124,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
20902124
// The buffer can be reused, and the size is loop invariant: it only
20912125
// depends on the iteration graph's toposort.
20922126
builder.setInsertionPointAfter(localInsertPos);
2093-
Value bufSize = C_IDX(1);
2094-
Value c2 = C_IDX(2);
2127+
Value tupleCnt = C_IDX(1);
20952128
// Accumulates the size required to cache the pLo for the slice.
20962129
// E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
20972130
// level, we need at most a memref<d0xindex>.
@@ -2108,16 +2141,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
21082141
assert(!sliceMeta[tid][curLevel - 1].empty());
21092142
auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
21102143
assert(stride == 1 && "Not yet implemented");
2111-
bufSize = MULI(bufSize, sz);
2144+
tupleCnt = MULI(tupleCnt, sz);
21122145
}
2113-
// For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
2114-
// because slice creates segments in the index buffer so that the pHi for
2115-
// the current level is no longer the pLo for the next level.
2116-
bufSize = MULI(bufSize, C_IDX(kSliceIterWidth));
2117-
// Additional two metadata {memSize, idx} at head.
2118-
bufSize = ADDI(bufSize, c2);
21192146
for (Value &cache : slicePosBuffer[tid][lvl])
2120-
cache = genAlloca(builder, loc, bufSize, builder.getIndexType());
2147+
cache = allocSlicePosBuf(builder, loc, tupleCnt);
21212148
}
21222149

21232150
if (sliceInfo.isInitialTensor() ||
@@ -2147,7 +2174,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
21472174
llvm_unreachable("TODO");
21482175

21492176
// else generate code to compute next non empty slice.
2150-
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
2177+
Value c0 = C_IDX(0), c1 = C_IDX(1);
21512178

21522179
SliceInfo &info = sliceStack[tid].back();
21532180
assert(info.slicedOnLvl == lvl);
@@ -2194,24 +2221,24 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
21942221
// offset = minCrd - size + 1;
21952222
// }
21962223
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
2197-
reduc[2] = absOffset; // restore value.
2198-
Value pSt = c2; // pointer starting index
2199-
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
2200-
reduc[0] = lvlSizes[tid][lvl]; // next min coord
2201-
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
2224+
reduc[2] = absOffset; // restore value.
2225+
Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf); // memSize
2226+
reduc[0] = lvlSizes[tid][lvl]; // next min coord
2227+
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
22022228
auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
22032229
auto forOp = scf::buildLoopNest(
2204-
builder, loc, pSt, mSz, C_IDX(kSliceIterWidth), loopArgs,
2230+
builder, loc, c0, mSz, c1, loopArgs,
22052231
[this, tid, lvl, c1, sPtrBuf,
22062232
&info](OpBuilder &builder, Location loc, ValueRange ivs,
22072233
ValueRange iterArgs) -> scf::ValueVector {
22082234
Value curMinCrd = iterArgs[0];
22092235
Value isNonEmpty = iterArgs[1];
22102236

22112237
Type idxTp = builder.getIndexType();
2212-
Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front());
2213-
Value pHi =
2214-
genIndexLoad(builder, loc, sPtrBuf, ADDI(ivs.front(), c1));
2238+
Value pLo = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
2239+
SlicePosKind::kLo);
2240+
Value pHi = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
2241+
SlicePosKind::kHi);
22152242
//
22162243
// if (pLo < pHi) // Only loads when inbound.
22172244
// coord = load[pLo]
@@ -2235,8 +2262,8 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
22352262
&ifEqual.getThenRegion().front());
22362263
Value newPlo = ADDI(pLo, c1);
22372264
// Updates the cache.
2238-
builder.create<memref::StoreOp>(loc, newPlo, sPtrBuf,
2239-
ivs.front());
2265+
updateSlicePos(builder, loc, sPtrBuf, newPlo, ivs.front(),
2266+
SlicePosKind::kLo);
22402267
YIELD(newPlo);
22412268
}
22422269
/* else coord != minCrd */ {

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@ class SparsificationAndBufferizationPass
150150
pm.addPass(
151151
createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
152152
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
153+
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
153154
if (vectorLength > 0) {
154-
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
155155
pm.addPass(createSparseVectorizationPass(
156156
vectorLength, enableVLAVectorization, enableSIMDIndex32));
157157
}

0 commit comments

Comments
 (0)