Skip to content

Commit c6f97d5

Browse files
author
Peiming Liu
committed
[mlir][sparse] set up the skeleton for SparseTensorLevel abstraction.
1 parent 59f7f35 commit c6f97d5

File tree

5 files changed

+197
-64
lines changed

5 files changed

+197
-64
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
1919
Utils/IterationGraphSorter.cpp
2020
Utils/LoopEmitter.cpp
2121
Utils/SparseTensorDescriptor.cpp
22+
Utils/SparseTensorLevels.cpp
2223

2324
ADDITIONAL_HEADER_DIRS
2425
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 36 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,15 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
126126
// Generates a bool value for while loop condition that tries to iterate over a
127127
// fully reduced level with affine index expression.
128128
static Value genSparseReducedAffineCond(OpBuilder &builder, Location loc,
129-
Value crdBuf, Value crdHi, Value posit,
130-
Value posHi) {
129+
const SparseTensorLevel &level,
130+
Value crdHi, Value posit, Value posHi) {
131131
Value inBound = CMPI(ult, posit, posHi);
132132
auto ifOp =
133133
builder.create<scf::IfOp>(loc, builder.getI1Type(), inBound, true);
134134
// if (inbound)
135135
// yield coord < crdHi
136136
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
137-
Value crd = genIndexLoad(builder, loc, crdBuf, posit);
137+
Value crd = level.peekCrdAt(builder, loc, posit);
138138
YIELD(CMPI(ult, crd, crdHi));
139139
// else
140140
// yield false
@@ -244,13 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
244244
Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
245245
TensorId tid, Level lvl, Value pLo,
246246
Value pHi) {
247-
const auto coordinates = coordinatesBuffers[tid][lvl];
248-
const auto sameCrd = genIndexLoad(builder, loc, coordinates, pLo);
247+
SparseTensorLevel &level = *lvls[tid][lvl];
248+
const Value sameCrd = level.peekCrdAt(builder, loc, pLo);
249249
auto whileOp = builder.create<scf::WhileOp>(
250250
loc, builder.getIndexType(), pLo,
251251
/*beforeBuilder=*/
252-
[pHi, coordinates, sameCrd](OpBuilder &builder, Location loc,
253-
ValueRange ivs) {
252+
[pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
254253
const auto pos = ivs[0];
255254
Value inBound = builder.create<arith::CmpIOp>(
256255
loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -261,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
261260
// Load the next coordinates only when inbound (to avoid OOB
262261
// accesses).
263262
builder.setInsertionPointToStart(ifInBound.thenBlock());
264-
Value crd = genIndexLoad(builder, loc, coordinates, pos);
263+
Value crd = level.peekCrdAt(builder, loc, pos);
265264
Value isSameCrd = builder.create<arith::CmpIOp>(
266265
loc, arith::CmpIPredicate::eq, crd, sameCrd);
267266
YIELD(isSameCrd);
@@ -284,11 +283,8 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
284283

285284
Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
286285
Level lvl) {
287-
// A load on the coordinates array yields the coordinate.
288-
const Value mem = coordinatesBuffers[tid][lvl];
289-
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
290286
const Value pos = posits[tid][lvl];
291-
const Value crd = genIndexLoad(builder, loc, mem, pos);
287+
const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
292288
return crd;
293289
}
294290

@@ -318,9 +314,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
318314
this->segHi.assign(numTensors, std::vector<Value>());
319315
this->posits.assign(numTensors, std::vector<Value>());
320316
this->coords.assign(numTensors, std::vector<Value>());
321-
this->positionsBuffers.assign(numTensors, std::vector<Value>());
322-
this->coordinatesBuffers.assign(numTensors, std::vector<Value>());
323317
this->valBuffer.assign(numTensors, nullptr);
318+
this->lvls.resize(numTensors);
324319
this->isSparseSlices.assign(numTensors, false);
325320
this->sliceOffsets.assign(numTensors, std::vector<Value>());
326321
this->sliceStrides.assign(numTensors, std::vector<Value>());
@@ -377,8 +372,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
377372
segHi[tid].assign(lvlRank, Value());
378373
posits[tid].assign(lvlRank, Value());
379374
coords[tid].assign(lvlRank, Value());
380-
positionsBuffers[tid].assign(lvlRank, Value());
381-
coordinatesBuffers[tid].assign(lvlRank, Value());
375+
lvls[tid].resize(lvlRank);
376+
382377
sliceOffsets[tid].assign(lvlRank, Value());
383378
sliceStrides[tid].assign(lvlRank, Value());
384379

@@ -448,22 +443,7 @@ void LoopEmitter::initializeLoopEmit(
448443

449444
// Scan all levels of current tensor.
450445
for (Level l = 0; l < lvlRank; l++) {
451-
// This should be called only once at beginning.
452-
assert(!positionsBuffers[t][l] && !coordinatesBuffers[t][l] &&
453-
!highs[t][l]);
454-
const auto lvlTp = lvlTypes[t][l];
455-
// Handle sparse storage schemes.
456-
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
457-
// Generate sparse primitives to obtain positions and coordinates.
458-
positionsBuffers[t][l] = genToPositions(builder, loc, tensor, l);
459-
coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
460-
} else if (isSingletonLT(lvlTp) || is2OutOf4LT(lvlTp)) {
461-
// Singleton level, fetch coordinates.
462-
coordinatesBuffers[t][l] = genToCoordinates(builder, loc, tensor, l);
463-
} else {
464-
// Dense level, nothing to fetch.
465-
assert(isDenseLT(lvlTp));
466-
}
446+
lvls[t][l] = makeSparseTensorLevel(builder, loc, tensor, l);
467447

468448
// Find upper bound in current dimension.
469449
highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
@@ -756,8 +736,7 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
756736
crdHi = ADDI(getMostRecentSliceOnLvl(tid, lvl).offset, remSz);
757737
}
758738
assert(crdHi);
759-
return genSparseReducedAffineCond(builder, loc,
760-
coordinatesBuffers[tid][lvl], crdHi,
739+
return genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl], crdHi,
761740
ivs[0], highs[tid][lvl]);
762741
}
763742
case LoopCondKind::SparseAffineUnRedCond: {
@@ -802,10 +781,9 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
802781
sliceTupleFwdCnt[tid][lvl] = SUBI(ivs[0], posits[tid][lvl]);
803782
// Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
804783
Value posit = ivs[0];
805-
Value crdBuf = coordinatesBuffers[tid][lvl];
806784
// We need to subtract the offset to get relative coordinates.
807785
// TODO: Maybe assert relC >=0 during runtime in debug build?
808-
Value absC = genIndexLoad(builder, loc, crdBuf, posit);
786+
Value absC = lvls[tid][lvl]->peekCrdAt(builder, loc, posit);
809787
auto relC = SUBI(absC, getFinalSliceOnLvl(tid, lvl).offset);
810788
posits[tid][lvl] = posit;
811789
coords[tid][lvl] = relC;
@@ -1189,9 +1167,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
11891167
// The induction variable gives the position.
11901168
const Value pos = forOp.getInductionVar();
11911169
posits[tid][lvl] = pos;
1192-
// Generating a load on the coordinates array yields the crd.
1193-
const Value mem = coordinatesBuffers[tid][lvl];
1194-
const Value crd = genIndexLoad(builder, loc, mem, pos);
1170+
const Value crd = lvls[tid][lvl]->peekCrdAt(builder, loc, pos);
11951171
coords[tid][lvl] = crd;
11961172

11971173
// Generate an if-condition to filter out coordinates that are not
@@ -1255,7 +1231,11 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
12551231
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
12561232
assert(lvl == 0 || posits[tid][lvl - 1]);
12571233
if (isCompressedLT(lvlTp) || isLooseCompressedLT(lvlTp)) {
1258-
const Value mem = positionsBuffers[tid][lvl];
1234+
// TODO: eliminate the cast upon feature complete.
1235+
const Value mem =
1236+
isCompressedLT(lvlTp)
1237+
? static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer
1238+
: static_cast<LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer;
12591239

12601240
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1];
12611241
if (isLooseCompressedLT(lvlTp))
@@ -1623,8 +1603,7 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
16231603
/*beforeBuilder=*/
16241604
[this, posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
16251605
ValueRange args) {
1626-
Value cond = genSparseReducedAffineCond(builder, loc,
1627-
coordinatesBuffers[tid][lvl],
1606+
Value cond = genSparseReducedAffineCond(builder, loc, *lvls[tid][lvl],
16281607
sliceHi, args[0], posHi);
16291608
// continue if not yet break nor out of bound.
16301609
builder.create<scf::ConditionOp>(loc, cond, args);
@@ -1848,12 +1827,14 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18481827
Value pHi, pLo;
18491828
if (lvl == 0) {
18501829
pLo = c0;
1851-
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][0], c1);
1830+
// TODO: eliminate the cast upon feature complete.
1831+
Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][0]).posBuffer;
1832+
pHi = genIndexLoad(builder, loc, pBuf, c1);
18521833
} else {
1853-
pLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
1854-
posits[tid][lvl - 1]);
1855-
pHi = genIndexLoad(builder, loc, positionsBuffers[tid][lvl],
1856-
ADDI(posits[tid][lvl - 1], c1));
1834+
// TODO: eliminate the cast upon feature complete.
1835+
Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
1836+
pLo = genIndexLoad(builder, loc, pBuf, posits[tid][lvl - 1]);
1837+
pHi = genIndexLoad(builder, loc, pBuf, ADDI(posits[tid][lvl - 1], c1));
18571838
}
18581839
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
18591840
updateSlicePos(builder, loc, sPtrBuf, pLo, c0, SlicePosKind::kLo);
@@ -1868,7 +1849,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
18681849
// nonempty. though we assume that even on empty sparse tensors, a non-empty
18691850
// ptr/idx buffer is allocated for each level so it would not cause OOB to
18701851
// avoid generating an ifOp here.
1871-
Value minCrd = genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
1852+
Value minCrd = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
18721853

18731854
// FIXME: We need the relative offset related to the base slice.
18741855
Value absOffset = offsetFromMinCoord(builder, loc, minCrd, nxSz, isNonEmpty);
@@ -1955,9 +1936,10 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19551936
Value &curTupleCnt = reduc[2];
19561937

19571938
Value pHi = ADDI(iv, c1);
1958-
Value sPLo = genIndexLoad(builder, loc, positionsBuffers[tid][lvl], iv);
1959-
Value sPHi =
1960-
genIndexLoad(builder, loc, positionsBuffers[tid][lvl], pHi);
1939+
// TODO: eliminate the cast upon feature complete.
1940+
Value pBuf = static_cast<CompressedLevel &>(*lvls[tid][lvl]).posBuffer;
1941+
Value sPLo = genIndexLoad(builder, loc, pBuf, iv);
1942+
Value sPHi = genIndexLoad(builder, loc, pBuf, pHi);
19611943

19621944
// isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
19631945
// one non-empty lvl, the slice is non-empty.
@@ -1975,8 +1957,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
19751957
// }
19761958
OpBuilder::InsertionGuard guard(builder);
19771959
builder.setInsertionPointToStart(ifNonEmpty.thenBlock());
1978-
Value curC =
1979-
genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], sPLo);
1960+
Value curC = lvls[tid][lvl]->peekCrdAt(builder, loc, sPLo);
19801961
Value isSmaller = CMPI(ult, curC, minCrd);
19811962
Value newMin = SELECT(isSmaller, curC, minCrd);
19821963
YIELD(newMin);
@@ -2176,8 +2157,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
21762157
/* if pLo < pHi */ {
21772158
builder.setInsertionPointToStart(&advPLo.getThenRegion().front());
21782159
// coord = load[pLo]
2179-
Value coord =
2180-
genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo);
2160+
Value coord = lvls[tid][lvl]->peekCrdAt(builder, loc, pLo);
21812161
Value pred = CMPI(eq, coord, info.minCrd);
21822162
auto ifEqual = builder.create<scf::IfOp>(loc, idxTp, pred, true);
21832163
/* if coord == minCrd */ {
@@ -2209,7 +2189,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
22092189
auto newMin =
22102190
builder.create<scf::IfOp>(loc, idxTp, lvlNonEmpty, true);
22112191
builder.setInsertionPointToStart(&newMin.getThenRegion().front());
2212-
YIELD(genIndexLoad(builder, loc, coordinatesBuffers[tid][lvl], pLo));
2192+
YIELD(lvls[tid][lvl]->peekCrdAt(builder, loc, pLo));
22132193

22142194
builder.setInsertionPointToStart(&newMin.getElseRegion().front());
22152195
YIELD(curMinCrd);

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
#include <vector>
1313

14+
#include "SparseTensorLevels.h"
15+
1416
#include "mlir/Dialect/SparseTensor/IR/Enums.h"
1517
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1618
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
@@ -241,12 +243,6 @@ class LoopEmitter {
241243
const std::vector<std::vector<Value>> &getPosits() const { return posits; };
242244
const std::vector<std::vector<Value>> &getCoords() const { return coords; };
243245
const std::vector<std::vector<Value>> &getHighs() const { return highs; };
244-
const std::vector<std::vector<Value>> &getPositionBuffers() const {
245-
return positionsBuffers;
246-
};
247-
const std::vector<std::vector<Value>> &getCoordinateBuffers() const {
248-
return coordinatesBuffers;
249-
};
250246
const std::vector<Value> &getValBuffer() const { return valBuffer; };
251247

252248
constexpr static llvm::StringLiteral getLoopEmitterLoopAttrName() {
@@ -648,8 +644,9 @@ class LoopEmitter {
648644
std::vector<std::vector<Value>> segHi;
649645
std::vector<std::vector<Value>> highs;
650646
std::vector<std::vector<Value>> lvlSizes;
651-
std::vector<std::vector<Value>> positionsBuffers; // to_positions
652-
std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
647+
// std::vector<std::vector<Value>> positionsBuffers; // to_positions
648+
// std::vector<std::vector<Value>> coordinatesBuffers; // to_coordinates
649+
std::vector<std::vector<std::unique_ptr<SparseTensorLevel>>> lvls;
653650
std::vector<Value> valBuffer; // to_value
654651

655652
//
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include "SparseTensorLevels.h"
2+
#include "CodegenUtils.h"
3+
4+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
5+
6+
using namespace mlir;
7+
using namespace mlir::sparse_tensor;
8+
9+
// Factory that builds the SparseTensorLevel object describing level `l` of
// tensor `t`.
//
// The level size value is produced by a LvlOp when the tensor type carries an
// encoding, and by a tensor::DimOp otherwise. The concrete level class is then
// selected from the level format:
//   - Dense:            DenseLevel (size only).
//   - Compressed:       CompressedLevel with the results of genToPositions
//                       and genToCoordinates.
//   - LooseCompressed:  LooseCompressedLevel with the results of
//                       genToPositions and genToCoordinates.
//   - Singleton:        SingletonLevel with the result of genToCoordinates.
//   - TwoOutOfFour:     TwoOutFourLevel with the result of genToCoordinates.
//
// Any ops created along the way are created through `builder` at `loc`.
std::unique_ptr<SparseTensorLevel>
10+
sparse_tensor::makeSparseTensorLevel(OpBuilder &builder, Location loc, Value t,
11+
Level l) {
12+
auto stt = getSparseTensorType(t);
13+
14+
LevelType lt = stt.getLvlType(l);
15+
// Pick the op that yields the level size depending on whether the tensor
// type has an encoding.
Value lvlSz = stt.hasEncoding()
16+
? builder.create<LvlOp>(loc, t, l).getResult()
17+
: builder.create<tensor::DimOp>(loc, t, l).getResult();
18+
19+
// NOTE(review): the result of getLevelFormat(lt) is dereferenced without a
// check -- presumably it always holds a value for a valid level type; confirm.
switch (*getLevelFormat(lt)) {
20+
case LevelFormat::Dense:
21+
return std::make_unique<DenseLevel>(lvlSz);
22+
case LevelFormat::Compressed: {
23+
Value posBuf = genToPositions(builder, loc, t, l);
24+
Value crdBuf = genToCoordinates(builder, loc, t, l);
25+
return std::make_unique<CompressedLevel>(lt, lvlSz, posBuf, crdBuf);
26+
}
27+
case LevelFormat::LooseCompressed: {
28+
Value posBuf = genToPositions(builder, loc, t, l);
29+
Value crdBuf = genToCoordinates(builder, loc, t, l);
30+
return std::make_unique<LooseCompressedLevel>(lt, lvlSz, posBuf, crdBuf);
31+
}
32+
case LevelFormat::Singleton: {
33+
Value crdBuf = genToCoordinates(builder, loc, t, l);
34+
return std::make_unique<SingletonLevel>(lt, lvlSz, crdBuf);
35+
}
36+
case LevelFormat::TwoOutOfFour: {
37+
Value crdBuf = genToCoordinates(builder, loc, t, l);
38+
return std::make_unique<TwoOutFourLevel>(lt, lvlSz, crdBuf);
39+
}
40+
}
41+
// Only reached when none of the cases above returned.
llvm_unreachable("unrecognizable level format");
42+
}
43+
44+
// Returns the value produced by genIndexLoad on this level's `crdBuffer` at
// position `pos`, using builder `b` and location `l`.
// NOTE(review): presumably this emits an index-typed load of the coordinate
// stored at `pos` -- confirm against genIndexLoad in CodegenUtils.
Value SparseLevel::peekCrdAt(OpBuilder &b, Location l, Value pos) const {
45+
return genIndexLoad(b, l, crdBuffer, pos);
46+
}

0 commit comments

Comments
 (0)