Skip to content

[mlir][sparse] refactoring: using util functions to query the index to load from position array for slice-driven loop. #73986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 97 additions & 70 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,23 +148,60 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
// Helper functions that load/store into the position buffer for slice-driven
// loops.
// The sliced pointer buffer is organized as:
// [size, curPtr] (two metadata) + [[pLo, pHi, pNext], ...] (list of tuples)
// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
// [pHi0, pHi1, pHi2, ...],
// [pNx0, pNx1, pNx2, ...]]
// Allocates the position buffer for a slice-driven loop: a two-element
// metadata header [size, curPtr] followed by kSliceIterWidth parallel
// arrays (pLo/pHi/pNext) of `tupleCnt` entries each.
static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
                              Value tupleCnt) {
  Value entries = MULI(tupleCnt, C_IDX(kSliceIterWidth));
  // Reserve two extra slots at the head for the {memSize, idx} metadata.
  Value total = ADDI(entries, C_IDX(2));
  return genAlloca(builder, loc, total, builder.getIndexType());
}
// TODO: We should use SSA value for it.
// Gets and sets metadata.
// Reads curPtr from its metadata slot (index 1) of the slice position buffer.
// TODO: We should use SSA value for it.
static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
  Value curPtrIdx = C_IDX(1);
  return genIndexLoad(builder, loc, sPosBuf, curPtrIdx);
}
// Writes `pPtr` into the curPtr metadata slot (index 1) of the buffer.
// TODO: We should use SSA value for it.
static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
                              Value pPtr) {
  Value curPtrIdx = C_IDX(1);
  builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, curPtrIdx);
}
static Value loadSliceNextPosPtrStart(OpBuilder &builder, Location loc,
Value sPosBuf, Value tupleIdx) {
// load the pNext in the current tuple specified by `tupleIdx`.
// 4 = 2 (two metadata) + 2 (pNext == tuple[2])
return genIndexLoad(builder, loc, sPosBuf, ADDI(tupleIdx, C_IDX(4)));
// Reads the number of tuples stored in the buffer (metadata slot 0).
static Value loadSlicePosTupleNum(OpBuilder &builder, Location loc,
                                  Value sPosBuf) {
  Value sizeIdx = C_IDX(0);
  return genIndexLoad(builder, loc, sPosBuf, sizeIdx);
}
// Writes `num` as the tuple count into metadata slot 0 of the buffer.
static void updateSlicePosTupleNum(OpBuilder &builder, Location loc, Value num,
                                   Value sPosBuf) {
  Value sizeIdx = C_IDX(0);
  builder.create<memref::StoreOp>(loc, num, sPosBuf, sizeIdx);
}

// Gets and sets position values for slice-driven loops.
// Discriminates which of the three parallel position arrays an index
// computation refers to: lower bound (kLo), upper bound (kHi), or the
// next-level start pointer (kNext).
enum class SlicePosKind { kLo, kHi, kNext };
// Computes the linearized index into `posBuf` for the `posKind` entry of
// tuple `tupleIdx`. The buffer layout is a two-element metadata header
// followed by kSliceIterWidth arrays of `tupleCnt` entries each
// (pLo, pHi, pNext).
static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
                            Value tupleIdx, SlicePosKind posKind) {
  switch (posKind) {
  case SlicePosKind::kLo:
    // The pLo array starts right after the metadata; no need to know the
    // tuple count, so avoid emitting the dim/sub/div ops for this case.
    return ADDI(tupleIdx, C_IDX(2));
  case SlicePosKind::kHi:
  case SlicePosKind::kNext: {
    // tupleCnt = (|posBuf| - 2 metadata slots) / kSliceIterWidth.
    Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
    Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
    if (posKind == SlicePosKind::kHi)
      return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
    return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
  }
  }
  llvm_unreachable("unexpected kind");
}
// Loads the position value of kind `posKind` for tuple `tupleIdx`.
static Value loadSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
                          Value tupleIdx, SlicePosKind posKind) {
  Value idx = getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind);
  return genIndexLoad(builder, loc, sPosBuf, idx);
}
// Stores `pos` as the `posKind` entry for tuple `tupleIdx`.
static void updateSlicePos(OpBuilder &builder, Location loc, Value sPosBuf,
                           Value pos, Value tupleIdx, SlicePosKind posKind) {
  Value idx = getSlicePosIdx(builder, loc, sPosBuf, tupleIdx, posKind);
  builder.create<memref::StoreOp>(loc, pos, sPosBuf, idx);
}

std::pair<Value, Value>
Expand Down Expand Up @@ -1446,13 +1483,13 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
// The first compressed level, setting up the position pointer for it.
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
// One step forwards in the parent level result in forwarding one `segment`
// (kSliceIterWidth) in the child sparse level.
Value fPosPtr = MULI(fcnt, C_IDX(kSliceIterWidth)); // forward ptr
// in the child sparse level.
Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
Value cPosPtr = ADDI(fPosPtr, pPosPtr); // current ptr
Value cPosPtr = ADDI(fcnt, pPosPtr); // current ptr
updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
// Loads the position pointer start for next level.
nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, cPosPtr);
nxPosPtr =
loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
curLvl++;
}

Expand All @@ -1464,10 +1501,10 @@ void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
for (; curLvl < leafLvl; curLvl++) {
assert(nxPosPtr);
if (!isDenseLT(lvlTypes[tid][curLvl])) {
nxPosPtr = MULI(nxPosPtr, C_IDX(kSliceIterWidth));
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
nxPosPtr = loadSliceNextPosPtrStart(builder, loc, sPosBuf, nxPosPtr);
nxPosPtr =
loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
}
}
}
Expand Down Expand Up @@ -1737,7 +1774,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
std::optional<std::pair<TensorId, Level>> firstResLvl, ValueRange userReduc,
LoopBodyBuilder bodyBuilder) {

Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
Value c0 = C_IDX(0), c1 = C_IDX(1);
Value pos = c0;
OpBuilder::InsertPoint ip;
SmallVector<Value> innerArgs(userReduc.begin(), userReduc.end());
Expand Down Expand Up @@ -1770,20 +1807,22 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(
unsigned depth = frontSlice.depth - 1;
Value offset = frontSlice.offset;
Value sPtrBuf = slicePosBuffer[tid][firstLvl][depth];
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf);
outerMost = builder.create<scf::ForOp>(
loc, c2, mSz, C_IDX(kSliceIterWidth), innerArgs,
[this, c1, c2, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
loc, c0, mSz, c1, innerArgs,
[this, tid, firstLvl, offset, sPtrBuf, &ip, &pos,
&innerArgs](OpBuilder &builder, Location loc, Value iv,
ValueRange iterArgs) {
// generate traversal for each level.
Value loopLo = genIndexLoad(builder, loc, sPtrBuf, iv);
Value loopHi = genIndexLoad(builder, loc, sPtrBuf, ADDI(iv, c1));
Value loopLo =
loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kLo);
Value loopHi =
loadSlicePos(builder, loc, sPtrBuf, iv, SlicePosKind::kHi);
// We need to remember the starting index for next level's
// position, because slice-driven loop breaks the level into
// non-consecutive segments.
builder.create<memref::StoreOp>(loc, iterArgs.back(), sPtrBuf,
ADDI(iv, c2).getResult());
updateSlicePos(builder, loc, sPtrBuf, iterArgs.back(), iv,
SlicePosKind::kNext);

auto [size, stride] = sliceMeta[tid][firstLvl].back();
assert(stride == 1 && "Not yet implemented");
Expand Down Expand Up @@ -1874,8 +1913,7 @@ ValueRange LoopEmitter::genUnResolvedSliceTreeTraverse(

void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2), c3 = C_IDX(3),
c4 = C_IDX(4);
Value c0 = C_IDX(0), c1 = C_IDX(1);
if (isDenseLT(lvlTypes[tid][lvl])) {
// Dense slice begin is trivial.
sliceStack[tid].emplace_back(/*minCoord=*/c0, /*offset=*/c0,
Expand All @@ -1897,10 +1935,10 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
ADDI(posits[tid][lvl - 1], c1));
}
// Fills out pIdxBuffer[tid][lvl][0] with [/*memSize =*/4, 0, pLo, pHi]
builder.create<memref::StoreOp>(loc, c4, sPtrBuf, c0); // memSize = 4
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1); // index = 0
builder.create<memref::StoreOp>(loc, pLo, sPtrBuf, c2); // pLo
builder.create<memref::StoreOp>(loc, pHi, sPtrBuf, c3); // pHi
updateSlicePosTupleNum(builder, loc, c1, sPtrBuf);
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);

// This is a non-empty tensor if pLo < pHi.
Value isNonEmpty = CMPI(ult, pLo, pHi);
Expand Down Expand Up @@ -1939,7 +1977,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
// }
void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
Value c0 = C_IDX(0), c1 = C_IDX(1);
unsigned depth = levelReducedDep[tid][lvl];
// The remaining slice size after reduction.
Value remSz = sliceMeta[tid][lvl][depth + 1].first;
Expand Down Expand Up @@ -1984,7 +2022,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
SmallVector<Value, 3> reduc = {
constantI1(builder, loc, false), // isNonEmpty
lvlSizes[tid][lvl], // minCoord
c2, // memSize
c0, // memSize
};

ValueRange result = genUnResolvedSliceTreeTraverse(
Expand All @@ -1993,7 +2031,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
Value &nonEmpty = reduc[0];
Value &minCrd = reduc[1];
Value &curMemSz = reduc[2];
Value &curTupleCnt = reduc[2];

Value pHi = ADDI(iv, c1);
Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
Expand Down Expand Up @@ -2025,28 +2063,26 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
YIELD(minCrd);
}
minCrd = ifNonEmpty.getResult(0);
builder.create<memref::StoreOp>(loc, sPLo, sPtrBuf, curMemSz);
Value nxtMemSize = ADDI(curMemSz, c1);
builder.create<memref::StoreOp>(loc, sPHi, sPtrBuf, nxtMemSize);
// curMemSize += kSliceIterWidth
curMemSz = ADDI(curMemSz, C_IDX(kSliceIterWidth));
updateSlicePos(builder, loc, sPtrBuf, sPLo, curTupleCnt,
SlicePosKind::kLo);
updateSlicePos(builder, loc, sPtrBuf, sPHi, curTupleCnt,
SlicePosKind::kHi);
curTupleCnt = ADDI(curTupleCnt, C_IDX(1));
});

Value isNonEmpty = result[0];
Value minCrd = result[1];
// Two metadata [memSize, idx].
// TODO: Can use an SSA value for these two metadata
builder.create<memref::StoreOp>(loc, result[2], sPtrBuf, c0);
builder.create<memref::StoreOp>(loc, c0, sPtrBuf, c1);
updateSlicePosTupleNum(builder, loc, result[2], sPtrBuf);
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
// FIXME: we need the relative offset related to the base slice.
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, lvl, depth + 1);
}

bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
Level lvl) {
Value c1 = C_IDX(1), c2 = C_IDX(2);

if (depFullyReduced(tid, lvl)) {
// Do not need to prepare for slice driven loop on dense level after it is
// fully reduced.
Expand All @@ -2055,14 +2091,12 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
// If constraints on the tensor is fully resolved. We do not need to
// generates slice begin any more, instead we fall back to TACO-based
// algorithm to (co)iterates over the slice.
Value pLoPtr =
loadSlicePosPtr(builder, loc, slicePosBuffer[tid][lvl].back());
pLoPtr = ADDI(pLoPtr, c2);
Value pHiPtr = ADDI(pLoPtr, c1);
Value sPosBuf = slicePosBuffer[tid][lvl].back();
Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
posits[tid][lvl] =
genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pLoPtr);
loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
highs[tid][lvl] =
genIndexLoad(builder, loc, slicePosBuffer[tid][lvl].back(), pHiPtr);
loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kHi);
return true;
}

Expand Down Expand Up @@ -2091,8 +2125,7 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
// The buffer can be reused, and the size is loop invariant: it only
// depends on the iteration graph's toposort.
builder.setInsertionPointAfter(localInsertPos);
Value bufSize = C_IDX(1);
Value c2 = C_IDX(2);
Value tupleCnt = C_IDX(1);
// Accumulates the size required to cache the pLo for the slice.
// E.g., if we want to cache the pIdx for slice<d0xd1xf64> on the second
// level. We at most need a memref<d0xindex>.
Expand All @@ -2109,16 +2142,10 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
assert(!sliceMeta[tid][curLevel - 1].empty());
auto [sz, stride] = sliceMeta[tid][curLevel - 1].back();
assert(stride == 1 && "Not yet implemented");
bufSize = MULI(bufSize, sz);
tupleCnt = MULI(tupleCnt, sz);
}
// For a triple of [pLo, pHi, pPtr]. Note that we can not compress pHi
// because slice creates segments in the index buffer so that the pHi for
// the current level is no longer the pLo for the next level.
bufSize = MULI(bufSize, C_IDX(kSliceIterWidth));
// Additional two metadata {memSize, idx} at head.
bufSize = ADDI(bufSize, c2);
for (Value &cache : slicePosBuffer[tid][lvl])
cache = genAlloca(builder, loc, bufSize, builder.getIndexType());
cache = allocSlicePosBuf(builder, loc, tupleCnt);
}

if (sliceInfo.isInitialTensor() ||
Expand Down Expand Up @@ -2148,7 +2175,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
llvm_unreachable("TODO");

// else generate code to compute next non empty slice.
Value c0 = C_IDX(0), c1 = C_IDX(1), c2 = C_IDX(2);
Value c0 = C_IDX(0), c1 = C_IDX(1);

SliceInfo &info = sliceStack[tid].back();
assert(info.slicedOnLvl == lvl);
Expand Down Expand Up @@ -2195,24 +2222,24 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
// offset = minCrd - size + 1;
// }
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
reduc[2] = absOffset; // restore value.
Value pSt = c2; // pointer starting index
Value mSz = genIndexLoad(builder, loc, sPtrBuf, c0); // memSize
reduc[0] = lvlSizes[tid][lvl]; // next min coord
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
reduc[2] = absOffset; // restore value.
Value mSz = loadSlicePosTupleNum(builder, loc, sPtrBuf); // memSize
reduc[0] = lvlSizes[tid][lvl]; // next min coord
reduc[1] = constantI1(builder, loc, false); // isNonEmpty
auto loopArgs = static_cast<ValueRange>(reduc).drop_back();
auto forOp = scf::buildLoopNest(
builder, loc, pSt, mSz, C_IDX(kSliceIterWidth), loopArgs,
builder, loc, c0, mSz, c1, loopArgs,
[this, tid, lvl, c1, sPtrBuf,
&info](OpBuilder &builder, Location loc, ValueRange ivs,
ValueRange iterArgs) -> scf::ValueVector {
Value curMinCrd = iterArgs[0];
Value isNonEmpty = iterArgs[1];

Type idxTp = builder.getIndexType();
Value pLo = genIndexLoad(builder, loc, sPtrBuf, ivs.front());
Value pHi =
genIndexLoad(builder, loc, sPtrBuf, ADDI(ivs.front(), c1));
Value pLo = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
SlicePosKind::kLo);
Value pHi = loadSlicePos(builder, loc, sPtrBuf, ivs.front(),
SlicePosKind::kHi);
//
// if (pLo < pHi) // Only loads when inbound.
// coord = load[pLo]
Expand All @@ -2236,8 +2263,8 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
&ifEqual.getThenRegion().front());
Value newPlo = ADDI(pLo, c1);
// Updates the cache.
builder.create<memref::StoreOp>(loc, newPlo, sPtrBuf,
ivs.front());
updateSlicePos(builder, loc, sPtrBuf, newPlo, ivs.front(),
SlicePosKind::kLo);
YIELD(newPlo);
}
/* else coord != minCrd */ {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ class SparsificationAndBufferizationPass
pm.addPass(
createSparseReinterpretMapPass(ReinterpretMapScope::kExceptGeneric));
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
if (vectorLength > 0) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addPass(createSparseVectorizationPass(
vectorLength, enableVLAVectorization, enableSIMDIndex32));
}
Expand Down
Loading