Skip to content

Commit baa192e

Browse files
authored
[mlir][sparse] optimize memory loads to SSA values when generating sparse conv. (#74787)
1 parent a539a09 commit baa192e

File tree

3 files changed

+224
-304
lines changed

3 files changed

+224
-304
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 42 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -147,40 +147,30 @@ static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
147147

148148
// Helper functions that load/store into the position buffer for slice-driven
149149
// loops.
150-
// The sliced pointer buffer is orgnized as:
151-
// [size, curPtr] (two metadata) + [[pLo0, pLo1, pLo2, ...],
152-
// [pHi0, pHi1, pHi2, ...],
153-
// [pNx0, pNx1, pNx2, ...]]
150+
// The sliced pointer buffer is organized as:
151+
// [[pLo0, pLo1, pLo2, ...],
152+
// [pHi0, pHi1, pHi2, ...],
153+
// [pNx0, pNx1, pNx2, ...]]
154154
static Value allocSlicePosBuf(OpBuilder &builder, Location loc,
155155
Value tupleCnt) {
156156
Value bufSz = MULI(tupleCnt, C_IDX(kSliceIterWidth));
157157
// Additional two metadata {memSize, idx} at head.
158-
bufSz = ADDI(bufSz, C_IDX(2));
159158
return genAlloca(builder, loc, bufSz, builder.getIndexType());
160159
}
161-
// TODO: We should use SSA value for it.
162-
// Gets and sets metadata.
163-
static Value loadSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf) {
164-
return genIndexLoad(builder, loc, sPosBuf, C_IDX(1));
165-
}
166-
static void updateSlicePosPtr(OpBuilder &builder, Location loc, Value sPosBuf,
167-
Value pPtr) {
168-
builder.create<memref::StoreOp>(loc, pPtr, sPosBuf, C_IDX(1));
169-
}
170160

171161
// Gets and sets position values for slice-driven loops.
172162
enum class SlicePosKind { kLo, kHi, kNext };
173163
static Value getSlicePosIdx(OpBuilder &builder, Location loc, Value posBuf,
174164
Value tupleIdx, SlicePosKind posKind) {
175165
Value dim = builder.create<memref::DimOp>(loc, posBuf, C_IDX(0));
176-
Value tupleCnt = DIVUI(SUBI(dim, C_IDX(2)), C_IDX(kSliceIterWidth));
166+
Value tupleCnt = DIVUI(dim, C_IDX(kSliceIterWidth));
177167
switch (posKind) {
178168
case SlicePosKind::kLo:
179-
return ADDI(tupleIdx, C_IDX(2));
169+
return tupleIdx;
180170
case SlicePosKind::kHi:
181-
return ADDI(tupleIdx, ADDI(tupleCnt, C_IDX(2)));
171+
return ADDI(tupleIdx, tupleCnt);
182172
case SlicePosKind::kNext:
183-
return ADDI(tupleIdx, ADDI(tupleCnt, ADDI(tupleCnt, C_IDX(2))));
173+
return ADDI(tupleIdx, MULI(tupleCnt, C_IDX(2)));
184174
}
185175
llvm_unreachable("unexpected kind");
186176
}
@@ -344,6 +334,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
344334
this->dependentLvlMap.assign(
345335
numTensors, std::vector<std::vector<std::pair<TensorLevel, unsigned>>>());
346336
this->slicePosBuffer.assign(numTensors, std::vector<std::vector<Value>>());
337+
this->sliceTupleNxStartIdx.assign(numTensors, std::vector<Value>());
338+
this->sliceTupleFwdCnt.assign(numTensors, std::vector<Value>());
339+
this->trivialSlice.assign(numTensors, std::vector<bool>());
347340
this->sliceMeta.assign(
348341
numTensors, std::vector<std::vector<std::pair<Value, unsigned>>>());
349342
this->sliceStack.assign(numTensors, std::vector<SliceInfo>());
@@ -394,6 +387,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
394387
dependentLvlMap[tid].assign(
395388
lvlRank, std::vector<std::pair<TensorLevel, unsigned>>());
396389
slicePosBuffer[tid].assign(lvlRank, std::vector<Value>());
390+
sliceTupleNxStartIdx[tid].assign(lvlRank, Value());
391+
sliceTupleFwdCnt[tid].assign(lvlRank, Value());
392+
trivialSlice[tid].assign(lvlRank, false);
397393
sliceMeta[tid].assign(lvlRank, std::vector<std::pair<Value, unsigned>>());
398394
sliceStack[tid].emplace_back(/*minCrd=*/Value(),
399395
/*offset=*/Value(), /*isNonEmpty*/ Value(),
@@ -806,6 +802,7 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
806802
assert(ivs.size() == 1);
807803
// Coord is the relative offset related to its parents.
808804
assert(sliceStack[tid].back().depth == 1 && "TODO: not yet implement");
805+
sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
809806
// Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
810807
Value posit = ivs[0];
811808
Value crdBuf = coordinatesBuffers[tid][lvl];
@@ -1324,6 +1321,12 @@ void LoopEmitter::enterTensorsAtDenseLvls(
13241321
} else {
13251322
posits[tid][lvl] =
13261323
genAddress(builder, loc, tid, lvl, ADDI(info.offset, iv));
1324+
Value fwdCnt = lvl == 0 || trivialSlice[tid][lvl]
1325+
? C_IDX(0)
1326+
: sliceTupleFwdCnt[tid][lvl - 1];
1327+
Value sz = sliceMeta[tid][lvl].back().first;
1328+
Value mul = MULI(fwdCnt, sz);
1329+
sliceTupleFwdCnt[tid][lvl] = ADDI(mul, iv);
13271330
}
13281331
levelReducedDep[tid][lvl]++;
13291332
} else {
@@ -1357,13 +1360,7 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
13571360
assert(isDenseLT(lvlTypes[tid][lvl]));
13581361
assert(*info.slicedOnLvl == lvl);
13591362
(void)reduced;
1360-
// Resets slices pointers as the resolved slices are invalidated after we
1361-
// moves forward to the next slice.
1362-
invalidateSliceIterIdx(rewriter, loc, tid, lvl);
13631363
info.minCrd = info.offset = info.isNonEmpty = Value();
1364-
} else {
1365-
forwardsReducedSliceLevelTreeIt(rewriter, loc, tid, lvl,
1366-
constantIndex(rewriter, loc, 1));
13671364
}
13681365
levelReducedDep[tid][lvl]--;
13691366
}
@@ -1443,54 +1440,6 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
14431440
}
14441441
}
14451442

1446-
void LoopEmitter::forwardsReducedSliceLevelTreeIt(OpBuilder &builder,
1447-
Location loc, TensorId tid,
1448-
Level rootLvl, Value fcnt) {
1449-
1450-
auto stt = getSparseTensorType(tensors[tid]);
1451-
1452-
// Finds a [Lvl, leafLvl) range, and all level in between are fully reduced
1453-
// sparse levels (but not resolved). Since we forward an iterator at higher
1454-
// level of the tree, the subtree need to be pruned.
1455-
Level leafLvl = rootLvl + 1;
1456-
while (leafLvl < stt.getLvlRank() && depFullyReduced(tid, leafLvl) &&
1457-
!stt.isDenseLvl(leafLvl)) {
1458-
leafLvl++;
1459-
}
1460-
1461-
Level curLvl = rootLvl + 1;
1462-
Value nxPosPtr = nullptr;
1463-
if (curLvl < leafLvl) {
1464-
assert(!isDenseLT(lvlTypes[tid][curLvl]));
1465-
// The first compressed level, setting up the position pointer for it.
1466-
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
1467-
// One step forwards in the parent level result in forwarding one `segment`
1468-
// in the child sparse level.
1469-
Value pPosPtr = loadSlicePosPtr(builder, loc, sPosBuf); // previous ptr
1470-
Value cPosPtr = ADDI(fcnt, pPosPtr); // current ptr
1471-
updateSlicePosPtr(builder, loc, sPosBuf, cPosPtr);
1472-
// Loads the position pointer start for next level.
1473-
nxPosPtr =
1474-
loadSlicePos(builder, loc, sPosBuf, cPosPtr, SlicePosKind::kNext);
1475-
curLvl++;
1476-
}
1477-
1478-
// TODO: This is not always needed, but we did it unconditionally for now for
1479-
// simplicity.
1480-
// It is only needed when `curLvl` is forwarded without traversing its child
1481-
// level (e.g., the level is in a conjunctive lattices and got pruned), such
1482-
// that the position pointer is not forwarded inside the loop.
1483-
for (; curLvl < leafLvl; curLvl++) {
1484-
assert(nxPosPtr);
1485-
if (!isDenseLT(lvlTypes[tid][curLvl])) {
1486-
Value sPosBuf = slicePosBuffer[tid][curLvl].back();
1487-
updateSlicePosPtr(builder, loc, sPosBuf, nxPosPtr);
1488-
nxPosPtr =
1489-
loadSlicePos(builder, loc, sPosBuf, nxPosPtr, SlicePosKind::kNext);
1490-
}
1491-
}
1492-
}
1493-
14941443
void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
14951444
MutableArrayRef<Value> reduc) {
14961445
const LoopInfo &loopInfo = loopStack.back();
@@ -1540,13 +1489,6 @@ void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
15401489
forwarded = CMPI(eq, coords[tid][lvl], iv);
15411490
operands.push_back(SELECT(forwarded, nxPos, pos));
15421491
}
1543-
{
1544-
OpBuilder::InsertionGuard guard(builder);
1545-
auto ifOp = builder.create<scf::IfOp>(loc, TypeRange{}, forwarded,
1546-
/*else=*/false);
1547-
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1548-
forwardsReducedSliceLevelTreeIt(builder, loc, tid, lvl, one);
1549-
}
15501492
// The coordinate is invalid now.
15511493
coords[tid][lvl] = nullptr;
15521494

@@ -1916,8 +1858,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
19161858
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
19171859
ADDI(posits[tid][lvl - 1], c1));
19181860
}
1919-
// Fills out pIdxBuffer[tid][lvl][0] with [0, pLo, pHi]
1920-
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
1861+
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
19211862
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
19221863
updateSlicePos(builder, loc, sPtrBuf, pHi, c0, SlicePosKind::kHi);
19231864
// Slice over a resolved parent, we only need one pair of pos hi and lo to
@@ -2056,8 +1997,6 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
20561997
Value isNonEmpty = result[0];
20571998
Value minCrd = result[1];
20581999
// Two metadata [memSize, idx].
2059-
// TODO: Can use an SSA value for these two metadata
2060-
updateSlicePosPtr(builder, loc, sPtrBuf, c0);
20612000
// FIXME: we need the relative offset related to the base slice.
20622001
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, remSz, isNonEmpty);
20632002
sliceStack[tid].emplace_back(minCrd, absOffset, isNonEmpty, result[2], lvl,
@@ -2066,16 +2005,30 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
20662005

20672006
bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
20682007
Level lvl) {
2008+
Value curLvlIdx = C_IDX(0);
20692009
if (depFullyReduced(tid, lvl)) {
2070-
// Do not need to prepare for slice driven loop on dense level after it is
2071-
// fully reduced.
2010+
if (lvl == 0 || trivialSlice[tid][lvl]) {
2011+
sliceTupleNxStartIdx[tid][lvl] = C_IDX(0);
2012+
} else {
2013+
if (isDenseLT(lvlTypes[tid][lvl])) {
2014+
sliceTupleNxStartIdx[tid][lvl] = sliceTupleNxStartIdx[tid][lvl - 1];
2015+
} else {
2016+
assert(isCompressedLT(lvlTypes[tid][lvl]));
2017+
curLvlIdx = ADDI(sliceTupleNxStartIdx[tid][lvl - 1],
2018+
sliceTupleFwdCnt[0][lvl - 1]);
2019+
sliceTupleNxStartIdx[tid][lvl] =
2020+
loadSlicePos(builder, loc, slicePosBuffer[tid][lvl].back(),
2021+
curLvlIdx, SlicePosKind::kNext);
2022+
}
2023+
}
20722024
if (isDenseLT(lvlTypes[tid][lvl]))
20732025
return true;
2026+
2027+
Value sPosBuf = slicePosBuffer[tid][lvl].back();
20742028
// If the constraints on the tensor are fully resolved, we do not need to
20752029
// generate the slice begin any more; instead we fall back to the TACO-based
20762030
// algorithm to (co)iterate over the slice.
2077-
Value sPosBuf = slicePosBuffer[tid][lvl].back();
2078-
Value tupleIdx = loadSlicePosPtr(builder, loc, sPosBuf);
2031+
Value tupleIdx = curLvlIdx;
20792032
posits[tid][lvl] =
20802033
loadSlicePos(builder, loc, sPosBuf, tupleIdx, SlicePosKind::kLo);
20812034
highs[tid][lvl] =
@@ -2134,23 +2087,16 @@ bool LoopEmitter::genSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
21342087
if (sliceInfo.isInitialTensor() ||
21352088
(lvl >= 1 && lvlFullyResolved(tid, lvl - 1))) {
21362089
// First level, or the previous level has been fully resolved.
2090+
trivialSlice[tid][lvl] = true;
21372091
genResolvedSliceBegin(builder, loc, tid, lvl);
21382092
} else {
21392093
// The previous level has not been fully resolved.
2094+
trivialSlice[tid][lvl] = false;
21402095
genUnResolvedSliceBegin(builder, loc, tid, lvl);
21412096
}
21422097
return false;
21432098
}
21442099

2145-
void LoopEmitter::invalidateSliceIterIdx(OpBuilder &builder, Location loc,
2146-
TensorId tid, Level lvl) {
2147-
for (unsigned i = 0; i <= lvl; i++) {
2148-
if (!isDenseLT(lvlTypes[tid][i]) && !dependentLvlMap[tid][i].empty()) {
2149-
updateSlicePosPtr(builder, loc, slicePosBuffer[tid][i].back(), C_IDX(0));
2150-
}
2151-
}
2152-
}
2153-
21542100
std::tuple<Value, Value, Value>
21552101
LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
21562102
TensorId tid, Level lvl) {
@@ -2175,10 +2121,6 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
21752121
// isNonEmpty = false;
21762122
//
21772123
Value absOffset = info.offset;
2178-
// Resets slices pointers as the resolved slices are invalidated after we
2179-
// moves forward to the next slice.
2180-
invalidateSliceIterIdx(builder, loc, tid, lvl);
2181-
21822124
SmallVector<Value, 3> reduc = {info.minCrd, info.isNonEmpty, absOffset};
21832125
Value sPtrBuf = slicePosBuffer[tid][lvl][info.depth - 1];
21842126
Value fastPathP = CMPI(ugt, info.minCrd, absOffset);

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -453,11 +453,6 @@ class LoopEmitter {
453453
return tid < lvlTypes.size() && lvl < lvlTypes[tid].size();
454454
}
455455

456-
/// Forwards the (conceptual) "tree iterator" when iterating over a fully
457-
/// reduced slice created by index-reduction.
458-
void forwardsReducedSliceLevelTreeIt(OpBuilder &builder, Location loc,
459-
TensorId tid, Level lvl, Value fcnt);
460-
461456
/// Prepares loop for iterating over `tensor[lvl]`, under the assumption
462457
/// that `tensor[0...lvl-1]` loops have already been set up.
463458
void prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
@@ -610,11 +605,6 @@ class LoopEmitter {
610605
void genUnResolvedSliceBegin(OpBuilder &builder, Location loc, TensorId tid,
611606
Level lvl);
612607

613-
/// Invalidates the index kept in slice postion buffers (by setting it to
614-
/// zero).
615-
/// TODO: We should instead use an SSA value for the index.
616-
void invalidateSliceIterIdx(OpBuilder &builder, Location loc, TensorId tid,
617-
Level lvl);
618608
/// Generates code to get the first non-empty slice of tid on lvl.
619609
/// return true if has already been resolved.
620610
bool genSliceBegin(OpBuilder &builder, Location loc, TensorId tid, Level lvl);
@@ -683,6 +673,9 @@ class LoopEmitter {
683673
// But they always start with the first pidx pointing to coord > slice.offset
684674
// to avoid iteration from the beginning.
685675
std::vector<std::vector<std::vector<Value>>> slicePosBuffer;
676+
std::vector<std::vector<Value>> sliceTupleNxStartIdx;
677+
std::vector<std::vector<Value>> sliceTupleFwdCnt;
678+
std::vector<std::vector<bool>> trivialSlice;
686679

687680
// The (size, stride) for each conceptual slice used for index reduction
688681
// loops.

0 commit comments

Comments
 (0)