[mlir][sparse] optimize memory loads to SSA values when generating sp… #74787

Merged · 2 commits · Dec 8, 2023
142 changes: 42 additions & 100 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -147,40 +147,30 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,

// Helper functions that load/store into the position buffer for slice-driven
// loops.
// The sliced pointer buffer is orgnized as:
// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
// [pHi0, pHi1, pHi2, ...],
// [pNx0, pNx1, pNx2, ...]]
// The sliced pointer buffer is organized as:
// [[pLo0, pLo1, pLo2, ...],
// [pHi0, pHi1, pHi2, ...],
// [pNx0, pNx1, pNx2, ...]]
static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
Value tupleCnt) {
Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
// Additional two metadata {memSize, idx} at head.
bufSz = ADDI(bufSz, C_IDX(2));
return genAlloca(builder, loc, bufSz, builder.getIndexType());
}
// TODO: We should use SSA value for it.
// Gets and sets metadata.
static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
}
static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
Value pPtr) {
builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
}

// Gets and sets position values for slice-driven loops.
enum class SlicePosKind { kLo, kHi, kNext };
static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
Value tupleIdx, SlicePosKind posKind) {
Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
switch (posKind) {
case SlicePosKind::kLo:
return ADDI(tupleIdx, C_IDX(2));
return tupleIdx;
case SlicePosKind::kHi:
return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
return ADDI(tupleIdx, tupleCnt);
case SlicePosKind::kNext:
return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
}
llvm_unreachable("unexpected kind");
}
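
To make the new layout easier to follow, here is a minimal standalone C++ sketch of the indexing arithmetic in the hunk above (illustrative names, not the actual MLIR builder code; it assumes kSliceIterWidth == 3, i.e. one region each for the pLo/pHi/pNx arrays):

```cpp
#include <cassert>
#include <cstdint>

// Assumed width: one slot per tuple for each of pLo, pHi and pNx.
constexpr uint64_t kSliceIterWidth = 3;

enum class SlicePosKind { kLo, kHi, kNext };

// Buffer size for `tupleCnt` tuples: no "+ 2" header, since the old
// [size, curPtr] metadata slots are gone.
uint64_t slicePosBufSize(uint64_t tupleCnt) {
  return tupleCnt * kSliceIterWidth;
}

// Mirrors the new getSlicePosIdx: the slot index is a pure function of
// tupleIdx and tupleCnt.
uint64_t slicePosIdx(uint64_t tupleIdx, uint64_t tupleCnt, SlicePosKind kind) {
  assert(tupleIdx < tupleCnt);
  switch (kind) {
  case SlicePosKind::kLo:
    return tupleIdx;                // [0, tupleCnt)
  case SlicePosKind::kHi:
    return tupleIdx + tupleCnt;     // [tupleCnt, 2 * tupleCnt)
  case SlicePosKind::kNext:
    return tupleIdx + 2 * tupleCnt; // [2 * tupleCnt, 3 * tupleCnt)
  }
  return 0; // unreachable
}
```

None of these indices depend on a "current pointer" slot stored in the buffer itself, which is what lets the loadSlicePosPtr/updateSlicePosPtr helpers be deleted in favor of SSA values in the rest of the diff.
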
@@ -344,6 +334,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
this->dependentLvlMap.assign(
numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
this->trivialSlice.assign(numTensors, std::vector<bool>());
this->sliceMeta.assign(
numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
@@ -394,6 +387,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
dependentLvlMap[tid].assign(
lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
sliceTupleFwdCnt[tid].assign(lvlRank, Value());
trivialSlice[tid].assign(lvlRank, false);
sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
sliceStack[tid].emplace_back(/*minCrd=*/Value(),
/*offset=*/Value(), /*isNonEmpty*/ Value(),
@@ -806,6 +802,7 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
assert(ivs.size() == 1);
// Coord is the relative offset related to its parents.
assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
// Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
Value posit = ivs[0];
Value crdBuf = coordinatesBuffers[tid][lvl];
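
A trivial scalar model of the added line (hypothetical names), spelling out what is now kept as an SSA value instead of being written back into the position buffer:

```cpp
#include <cstdint>

// How far the while-loop induction variable has advanced past this level's
// start position, i.e. SUBI(ivs[0], posits[tid][lvl]) in the hunk above.
uint64_t forwardedThisLevel(uint64_t iv, uint64_t levelStartPos) {
  return iv - levelStartPos;
}
```
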
@@ -1324,6 +1321,12 @@ void LoopEmitter::enterTensorsAtDenseLvls(
} else {
posits[tid][lvl] =
genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
? C_IDX(0)
: sliceTupleFwdCnt[tid][lvl - 1];
Value sz = sliceMeta[tid][lvl].back().first;
Value mul = MULI(fwdCnt, sz);
sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
}
levelReducedDep[tid][lvl]++;
} else {
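
Together with the count captured in genWhileLoopBody above, the dense-level update linearizes the forwarded tuple count across slice levels, much like a row-major index computation. A rough scalar sketch of the recurrence (hypothetical names, not the IR-building code):

```cpp
#include <cstdint>

// fwdCnt[lvl] = fwdCnt[lvl - 1] * sliceSize[lvl] + iv, restarting from iv
// alone at level 0 or when the slice at this level is "trivial" (resolved
// against an already fully-resolved parent).
uint64_t forwardCount(uint64_t lvl, bool trivialSlice, uint64_t parentFwdCnt,
                      uint64_t sliceSize, uint64_t iv) {
  uint64_t fwd = (lvl == 0 || trivialSlice) ? 0 : parentFwdCnt;
  return fwd * sliceSize + iv;
}
// Example: parent forwarded 2 tuples, slice size 4, iv == 3
// -> forward count 2 * 4 + 3 == 11.
```
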
@@ -1357,13 +1360,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
assert(isDenseLT(lvlTypes[tid][lvl]));
assert(*info.slicedOnLvl == lvl);
(void)reduced;
// Resets slices pointers as the resolved slices are invalidated after we
// moves forward to the next slice.
invalidateSliceIterIdx(rewriter, loc, tid, lvl);
info.minCrd = info.offset = info.isNonEmpty = Value();
} else {
forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
constantIndex(rewriter, loc, 1));
}
levelReducedDep[tid][lvl]--;
}
@@ -1443,54 +1440,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
}
}

void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
Location loc, TensorId tid,
Level rootLvl, Value fcnt) {

auto stt = getSparseTensorType(tensors[tid]);

// Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
// sparse levels (but not resolved). Since we forward an iterator at higher
// level of the tree, the subtree need to be pruned.
Level leafLvl = rootLvl + 1;
while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
!stt.isDenseLvl(leafLvl)) {
leafLvl++;
}

Level curLvl = rootLvl + 1;
Value nxPosPtr = nullptr;
if (curLvl < leafLvl) {
assert(!isDenseLT(lvlTypes[tid][curLvl]));
// The first compressed level, setting up the position pointer for it.
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
// One step forwards in the parent level result in forwarding one `segment`
// in the child sparse level.
Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
Value cPosPtr = ADDI(fcnt, pPosPtr); // current ptr
updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
// Loads the position pointer start for next level.
nxPosPtr =
loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
curLvl++;
}

// TODO: This is not always needed, but we did it unconditionally for now for
// simplicity.
// It is only needed when `curLvl` is forwarded without traversing its child
// level (e.g., the level is in a conjunctive lattices and got pruned), such
// that the position pointer is not forwarded inside the loop.
for (; curLvl < leafLvl; curLvl++) {
assert(nxPosPtr);
if (!isDenseLT(lvlTypes[tid][curLvl])) {
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
nxPosPtr =
loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
}
}
}

void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
MutableArrayRef<Value> reduc) {
const LoopInfo &loopInfo = loopStack.back();
@@ -1540,13 +1489,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
forwarded = CMPI(eq, coords[tid][lvl], iv);
operands.push_back(SELECT(forwarded, nxPos, pos));
}
{
OpBuilder::InsertionGuard guard(builder);
auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
/*else=*/false);
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
}
// The coordinate is invalid now.
coords[tid][lvl] = nullptr;

@@ -1916,8 +1858,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
ADDI(posits[tid][lvl - 1], c1));
}
// Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
// Slice over a resolved parent, we only need one pair of pos hi and lo to
@@ -2056,8 +1997,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
Value isNonEmpty = result[0];
Value minCrd = result[1];
// Two metadata [memSize, idx].
// TODO: Can use an SSA value for these two metadata
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
// FIXME: we need the relative offset related to the base slice.
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
@@ -2066,16 +2005,30 @@

bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
Level lvl) {
Value curLvlIdx = C_IDX(0);
if (depFullyReduced(tid, lvl)) {
// Do not need to prepare for slice driven loop on dense level after it is
// fully reduced.
if (lvl == 0 || trivialSlice[tid][lvl]) {
sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
} else {
if (isDenseLT(lvlTypes[tid][lvl])) {
sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
} else {
assert(isCompressedLT(lvlTypes[tid][lvl]));
curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
sliceTupleFwdCnt[0][lvl - 1]);
sliceTupleNxStartIdx[tid][lvl] =
loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
curLvlIdx, SlicePosKind::kNext);
}
}
if (isDenseLT(lvlTypes[tid][lvl]))
return true;

Value sPosBuf = slicePosBuffer[tid][lvl].back();
// If constraints on the tensor is fully resolved. We do not need to
// generates slice begin any more, instead we fall back to TACO-based
// algorithm to (co)iterates over the slice.
Value sPosBuf = slicePosBuffer[tid][lvl].back();
Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
Value tupleIdx = curLvlIdx;
posits[tid][lvl] =
loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
highs[tid][lvl] =
@@ -2134,23 +2087,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
if (sliceInfo.isInitialTensor() ||
(lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
// First level or previous level has been full resolved.
trivialSlice[tid][lvl] = true;
genResolvedSliceBegin(builder, loc, tid, lvl);
} else {
// The previous level has not been full resolved.
trivialSlice[tid][lvl] = false;
genUnResolvedSliceBegin(builder, loc, tid, lvl);
}
return false;
}
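
The net effect in genSliceBegin: the tuple index that was previously reloaded from the curPtr metadata slot (loadSlicePosPtr) is now derived from the SSA values sliceTupleNxStartIdx and sliceTupleFwdCnt. A rough scalar model of that derivation for one level (hypothetical names; loadNext stands for the SlicePosKind::kNext load from this level's slice position buffer):

```cpp
#include <cstdint>
#include <functional>

struct SliceTupleState {
  uint64_t nxStartIdx; // models sliceTupleNxStartIdx[tid][lvl]
  uint64_t tupleIdx;   // index used to load this level's pLo/pHi pair
};

SliceTupleState beginSliceTuple(
    uint64_t lvl, bool trivial, uint64_t parentNxStart, uint64_t parentFwdCnt,
    const std::function<uint64_t(uint64_t)> &loadNext) {
  if (lvl == 0 || trivial)
    return {0, 0}; // always the first tuple of the buffer
  // Otherwise the parent's start plus its forwarded count selects the tuple
  // whose pNx entry becomes this level's start.
  uint64_t idx = parentNxStart + parentFwdCnt;
  return {loadNext(idx), idx};
}
```

Dense levels only propagate the parent's start index and return early, so no position-buffer access is generated for them at all.
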

void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
for (unsigned i = 0; i <= lvl; i++) {
if (!isDenseLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
}
}
}

std::tuple<Value, Value, Value>
LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
TensorId tid, Level lvl) {
@@ -2175,10 +2121,6 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
// isNonEmpty = false;
//
Value absOffset = info.offset;
// Resets slices pointers as the resolved slices are invalidated after we
// moves forward to the next slice.
invalidateSliceIterIdx(builder, loc, tid, lvl);

SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
Value fastPathP = CMPI(ugt, info.minCrd, absOffset);
13 changes: 3 additions & 10 deletions mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h
@@ -453,11 +453,6 @@ class LoopEmitter {
return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
}

/// Forwards the (conceptual) "tree iterator" when iterating over a fully
/// reduced slice created by index-reduction.
void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
TensorId tid, Level lvl, Value fcnt);

/// Prepares loop for iterating over `tensor[lvl]`, under the assumption
/// that `tensor[0...lvl-1]` loops have already been set up.
void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
@@ -610,11 +605,6 @@ class LoopEmitter {
void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
Level lvl);

/// Invalidates the index kept in slice postion buffers (by setting it to
/// zero).
/// TODO: We should instead use an SSA value for the index.
void invalidateSliceIterIdx(OpBuilder &builder, Location loc, TensorId tid,
Level lvl);
/// Generates code to get the first non-empty slice of tid on lvl.
/// return true if has already been resolved.
bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
@@ -683,6 +673,9 @@ class LoopEmitter {
// But they always starts with the first pidx pointing to coord > slice.offset
// to avoid iteration from the beginning.
std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
std::vector<std::vector<Value>> sliceTupleNxStartIdx;
std::vector<std::vector<Value>> sliceTupleFwdCnt;
std::vector<std::vector<bool>> trivialSlice;

// The (size, stride) for each conceptual slice used for index reduction
// loops.