@@ -126,15 +126,15 @@ static std::pair<Value, Value> fromSliceCrd(OpBuilder &builder, Location loc,
126
126
// Generates a bool value for while loop condition that tries to iterate over a
127
127
// fully reduced level with affine index expression.
128
128
static Value genSparseReducedAffineCond (OpBuilder &builder, Location loc,
129
- Value crdBuf, Value crdHi, Value posit ,
130
- Value posHi) {
129
+ const SparseTensorLevel &level ,
130
+ Value crdHi, Value posit, Value posHi) {
131
131
Value inBound = CMPI (ult, posit, posHi);
132
132
auto ifOp =
133
133
builder.create <scf::IfOp>(loc, builder.getI1Type (), inBound, true );
134
134
// if (inbound)
135
135
// yield coord < crdHi
136
136
builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
137
- Value crd = genIndexLoad (builder, loc, crdBuf , posit);
137
+ Value crd = level. peekCrdAt (builder, loc, posit);
138
138
YIELD (CMPI (ult, crd, crdHi));
139
139
// else
140
140
// yield false
@@ -244,13 +244,12 @@ Value LoopEmitter::genAddress(OpBuilder &builder, Location loc, TensorId tid,
244
244
Value LoopEmitter::genSegmentHigh (OpBuilder &builder, Location loc,
245
245
TensorId tid, Level lvl, Value pLo,
246
246
Value pHi) {
247
- const auto coordinates = coordinatesBuffers [tid][lvl];
248
- const auto sameCrd = genIndexLoad (builder, loc, coordinates , pLo);
247
+ SparseTensorLevel &level = *lvls [tid][lvl];
248
+ const Value sameCrd = level. peekCrdAt (builder, loc, pLo);
249
249
auto whileOp = builder.create <scf::WhileOp>(
250
250
loc, builder.getIndexType (), pLo,
251
251
/* beforeBuilder=*/
252
- [pHi, coordinates, sameCrd](OpBuilder &builder, Location loc,
253
- ValueRange ivs) {
252
+ [pHi, &level, sameCrd](OpBuilder &builder, Location loc, ValueRange ivs) {
254
253
const auto pos = ivs[0 ];
255
254
Value inBound = builder.create <arith::CmpIOp>(
256
255
loc, arith::CmpIPredicate::ult, pos, pHi);
@@ -261,7 +260,7 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
261
260
// Load the next coordinates only when inbound (to avoid OOB
262
261
// accesses).
263
262
builder.setInsertionPointToStart (ifInBound.thenBlock ());
264
- Value crd = genIndexLoad (builder, loc, coordinates , pos);
263
+ Value crd = level. peekCrdAt (builder, loc, pos);
265
264
Value isSameCrd = builder.create <arith::CmpIOp>(
266
265
loc, arith::CmpIPredicate::eq, crd, sameCrd);
267
266
YIELD (isSameCrd);
@@ -284,11 +283,8 @@ Value LoopEmitter::genSegmentHigh(OpBuilder &builder, Location loc,
284
283
285
284
Value LoopEmitter::genSparseCrd (OpBuilder &builder, Location loc, TensorId tid,
286
285
Level lvl) {
287
- // A load on the coordinates array yields the coordinate.
288
- const Value mem = coordinatesBuffers[tid][lvl];
289
- // / FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
290
286
const Value pos = posits[tid][lvl];
291
- const Value crd = genIndexLoad (builder, loc, mem , pos);
287
+ const Value crd = lvls[tid][lvl]-> peekCrdAt (builder, loc, pos);
292
288
return crd;
293
289
}
294
290
@@ -318,9 +314,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
318
314
this ->segHi .assign (numTensors, std::vector<Value>());
319
315
this ->posits .assign (numTensors, std::vector<Value>());
320
316
this ->coords .assign (numTensors, std::vector<Value>());
321
- this ->positionsBuffers .assign (numTensors, std::vector<Value>());
322
- this ->coordinatesBuffers .assign (numTensors, std::vector<Value>());
323
317
this ->valBuffer .assign (numTensors, nullptr );
318
+ this ->lvls .resize (numTensors);
324
319
this ->isSparseSlices .assign (numTensors, false );
325
320
this ->sliceOffsets .assign (numTensors, std::vector<Value>());
326
321
this ->sliceStrides .assign (numTensors, std::vector<Value>());
@@ -377,8 +372,8 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
377
372
segHi[tid].assign (lvlRank, Value ());
378
373
posits[tid].assign (lvlRank, Value ());
379
374
coords[tid].assign (lvlRank, Value ());
380
- positionsBuffers [tid].assign (lvlRank, Value () );
381
- coordinatesBuffers[tid]. assign (lvlRank, Value ());
375
+ lvls [tid].resize (lvlRank);
376
+
382
377
sliceOffsets[tid].assign (lvlRank, Value ());
383
378
sliceStrides[tid].assign (lvlRank, Value ());
384
379
@@ -448,22 +443,7 @@ void LoopEmitter::initializeLoopEmit(
448
443
449
444
// Scan all levels of current tensor.
450
445
for (Level l = 0 ; l < lvlRank; l++) {
451
- // This should be called only once at beginning.
452
- assert (!positionsBuffers[t][l] && !coordinatesBuffers[t][l] &&
453
- !highs[t][l]);
454
- const auto lvlTp = lvlTypes[t][l];
455
- // Handle sparse storage schemes.
456
- if (isCompressedLT (lvlTp) || isLooseCompressedLT (lvlTp)) {
457
- // Generate sparse primitives to obtain positions and coordinates.
458
- positionsBuffers[t][l] = genToPositions (builder, loc, tensor, l);
459
- coordinatesBuffers[t][l] = genToCoordinates (builder, loc, tensor, l);
460
- } else if (isSingletonLT (lvlTp) || is2OutOf4LT (lvlTp)) {
461
- // Singleton level, fetch coordinates.
462
- coordinatesBuffers[t][l] = genToCoordinates (builder, loc, tensor, l);
463
- } else {
464
- // Dense level, nothing to fetch.
465
- assert (isDenseLT (lvlTp));
466
- }
446
+ lvls[t][l] = makeSparseTensorLevel (builder, loc, tensor, l);
467
447
468
448
// Find upper bound in current dimension.
469
449
highs[t][l] = lvlSizes[t][l] = lvlSzs[l];
@@ -756,8 +736,7 @@ Value LoopEmitter::genWhileLoopConditions(OpBuilder &builder, Location loc,
756
736
crdHi = ADDI (getMostRecentSliceOnLvl (tid, lvl).offset , remSz);
757
737
}
758
738
assert (crdHi);
759
- return genSparseReducedAffineCond (builder, loc,
760
- coordinatesBuffers[tid][lvl], crdHi,
739
+ return genSparseReducedAffineCond (builder, loc, *lvls[tid][lvl], crdHi,
761
740
ivs[0 ], highs[tid][lvl]);
762
741
}
763
742
case LoopCondKind::SparseAffineUnRedCond: {
@@ -802,10 +781,9 @@ std::optional<Value> LoopEmitter::genWhileLoopBody(OpBuilder &builder,
802
781
sliceTupleFwdCnt[tid][lvl] = SUBI (ivs[0 ], posits[tid][lvl]);
803
782
// Update c = absOffset[lvl][depth] - absOffset[lvl][depth - 1]
804
783
Value posit = ivs[0 ];
805
- Value crdBuf = coordinatesBuffers[tid][lvl];
806
784
// We need to subtract the offset to get relative coordinates.
807
785
// TODO: Maybe assert relC >=0 during runtime in debug build?
808
- Value absC = genIndexLoad (builder, loc, crdBuf , posit);
786
+ Value absC = lvls[tid][lvl]-> peekCrdAt (builder, loc, posit);
809
787
auto relC = SUBI (absC, getFinalSliceOnLvl (tid, lvl).offset );
810
788
posits[tid][lvl] = posit;
811
789
coords[tid][lvl] = relC;
@@ -1189,9 +1167,7 @@ Operation *LoopEmitter::enterFilterLoopOverTensorAtLvl(
1189
1167
// The induction variable gives the position.
1190
1168
const Value pos = forOp.getInductionVar ();
1191
1169
posits[tid][lvl] = pos;
1192
- // Generating a load on the coordinates array yields the crd.
1193
- const Value mem = coordinatesBuffers[tid][lvl];
1194
- const Value crd = genIndexLoad (builder, loc, mem, pos);
1170
+ const Value crd = lvls[tid][lvl]->peekCrdAt (builder, loc, pos);
1195
1171
coords[tid][lvl] = crd;
1196
1172
1197
1173
// Generate an if-condition to filter out coordinates that are not
@@ -1255,7 +1231,11 @@ void LoopEmitter::prepareLoopOverTensorAtLvl(OpBuilder &builder, Location loc,
1255
1231
/// FIXME: See the [CLARIFY_POSITS_LVL] note in the header.
1256
1232
assert (lvl == 0 || posits[tid][lvl - 1 ]);
1257
1233
if (isCompressedLT (lvlTp) || isLooseCompressedLT (lvlTp)) {
1258
- const Value mem = positionsBuffers[tid][lvl];
1234
+ // TODO: eliminate the cast upon feature complete.
1235
+ const Value mem =
1236
+ isCompressedLT (lvlTp)
1237
+ ? static_cast <CompressedLevel &>(*lvls[tid][lvl]).posBuffer
1238
+ : static_cast <LooseCompressedLevel &>(*lvls[tid][lvl]).posBuffer ;
1259
1239
1260
1240
Value pLo = lvl == 0 ? c0 : posits[tid][lvl - 1 ];
1261
1241
if (isLooseCompressedLT (lvlTp))
@@ -1623,8 +1603,7 @@ std::pair<Operation *, ValueRange> LoopEmitter::genSliceLvlTraverseLoop(
1623
1603
/* beforeBuilder=*/
1624
1604
[this , posHi, sliceHi, tid, lvl](OpBuilder &builder, Location loc,
1625
1605
ValueRange args) {
1626
- Value cond = genSparseReducedAffineCond (builder, loc,
1627
- coordinatesBuffers[tid][lvl],
1606
+ Value cond = genSparseReducedAffineCond (builder, loc, *lvls[tid][lvl],
1628
1607
sliceHi, args[0 ], posHi);
1629
1608
// continue if not yet break nor out of bound.
1630
1609
builder.create <scf::ConditionOp>(loc, cond, args);
@@ -1848,12 +1827,14 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
1848
1827
Value pHi, pLo;
1849
1828
if (lvl == 0 ) {
1850
1829
pLo = c0;
1851
- pHi = genIndexLoad (builder, loc, positionsBuffers[tid][0 ], c1);
1830
+ // TODO: eliminate the cast upon feature complete.
1831
+ Value pBuf = static_cast <CompressedLevel &>(*lvls[tid][0 ]).posBuffer ;
1832
+ pHi = genIndexLoad (builder, loc, pBuf, c1);
1852
1833
} else {
1853
- pLo = genIndexLoad (builder, loc, positionsBuffers[tid][lvl],
1854
- posits [tid][lvl - 1 ]) ;
1855
- pHi = genIndexLoad (builder, loc, positionsBuffers [tid][lvl],
1856
- ADDI (posits[tid][lvl - 1 ], c1));
1834
+ // TODO: eliminate the cast upon feature complete.
1835
+ Value pBuf = static_cast <CompressedLevel &>(*lvls [tid][lvl]). posBuffer ;
1836
+ pLo = genIndexLoad (builder, loc, pBuf, posits [tid][lvl - 1 ]);
1837
+ pHi = genIndexLoad (builder, loc, pBuf, ADDI (posits[tid][lvl - 1 ], c1));
1857
1838
}
1858
1839
// Fills out pIdxBuffer[tid][lvl][0] with [pLo, pHi]
1859
1840
updateSlicePos (builder, loc, sPtrBuf , pLo, c0, SlicePosKind::kLo );
@@ -1868,7 +1849,7 @@ void LoopEmitter::genResolvedSliceBegin(OpBuilder &builder, Location loc,
1868
1849
// nonempty. though we assume that even on empty sparse tensors, a non-empty
1869
1850
// ptr/idx buffer is allocated for each level so it would not cause OOB to
1870
1851
// avoid generating an ifOp here.
1871
- Value minCrd = genIndexLoad (builder, loc, coordinatesBuffers [tid][lvl], pLo);
1852
+ Value minCrd = lvls [tid][lvl]-> peekCrdAt (builder, loc , pLo);
1872
1853
1873
1854
// FIXME: We need the relative offset related to the base slice.
1874
1855
Value absOffset = offsetFromMinCoord (builder, loc, minCrd, nxSz, isNonEmpty);
@@ -1955,9 +1936,10 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
1955
1936
Value &curTupleCnt = reduc[2 ];
1956
1937
1957
1938
Value pHi = ADDI (iv, c1);
1958
- Value sPLo = genIndexLoad (builder, loc, positionsBuffers[tid][lvl], iv);
1959
- Value sPHi =
1960
- genIndexLoad (builder, loc, positionsBuffers[tid][lvl], pHi);
1939
+ // TODO: eliminate the cast upon feature complete.
1940
+ Value pBuf = static_cast <CompressedLevel &>(*lvls[tid][lvl]).posBuffer ;
1941
+ Value sPLo = genIndexLoad (builder, loc, pBuf, iv);
1942
+ Value sPHi = genIndexLoad (builder, loc, pBuf, pHi);
1961
1943
1962
1944
// isNonEmpty = isNonEmpty || lvlNonEmpty, i.e., as long as there is
1963
1945
// one non-empty lvl, the slice is non-empty.
@@ -1975,8 +1957,7 @@ void LoopEmitter::genUnResolvedSliceBegin(OpBuilder &builder, Location loc,
1975
1957
// }
1976
1958
OpBuilder::InsertionGuard guard (builder);
1977
1959
builder.setInsertionPointToStart (ifNonEmpty.thenBlock ());
1978
- Value curC =
1979
- genIndexLoad (builder, loc, coordinatesBuffers[tid][lvl], sPLo );
1960
+ Value curC = lvls[tid][lvl]->peekCrdAt (builder, loc, sPLo );
1980
1961
Value isSmaller = CMPI (ult, curC, minCrd);
1981
1962
Value newMin = SELECT (isSmaller, curC, minCrd);
1982
1963
YIELD (newMin);
@@ -2176,8 +2157,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
2176
2157
/* if pLo < pHi */ {
2177
2158
builder.setInsertionPointToStart (&advPLo.getThenRegion ().front ());
2178
2159
// coord = load[pLo]
2179
- Value coord =
2180
- genIndexLoad (builder, loc, coordinatesBuffers[tid][lvl], pLo);
2160
+ Value coord = lvls[tid][lvl]->peekCrdAt (builder, loc, pLo);
2181
2161
Value pred = CMPI (eq, coord, info.minCrd );
2182
2162
auto ifEqual = builder.create <scf::IfOp>(loc, idxTp, pred, true );
2183
2163
/* if coord == minCrd */ {
@@ -2209,7 +2189,7 @@ LoopEmitter::genSliceNextInduction(OpBuilder &builder, Location loc,
2209
2189
auto newMin =
2210
2190
builder.create <scf::IfOp>(loc, idxTp, lvlNonEmpty, true );
2211
2191
builder.setInsertionPointToStart (&newMin.getThenRegion ().front ());
2212
- YIELD (genIndexLoad (builder, loc, coordinatesBuffers [tid][lvl], pLo));
2192
+ YIELD (lvls [tid][lvl]-> peekCrdAt (builder, loc , pLo));
2213
2193
2214
2194
builder.setInsertionPointToStart (&newMin.getElseRegion ().front ());
2215
2195
YIELD (curMinCrd);
0 commit comments