@@ -773,9 +773,6 @@ class SubSectIterator : public SparseIterator {
773
773
// SparseIterator derived classes implementation.
774
774
// ===----------------------------------------------------------------------===//
775
775
776
- SparseEmitStrategy SparseIterator::emitStrategy =
777
- SparseEmitStrategy::kFunctional ;
778
-
779
776
void SparseIterator::genInit (OpBuilder &b, Location l,
780
777
const SparseIterator *p) {
781
778
if (emitStrategy == SparseEmitStrategy::kDebugInterface ) {
@@ -1303,27 +1300,38 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
1303
1300
}
1304
1301
1305
1302
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1306
- sparse_tensor::makeSynLevelAndIterator (Value sz, unsigned tid, unsigned lvl) {
1303
+ sparse_tensor::makeSynLevelAndIterator (Value sz, unsigned tid, unsigned lvl,
1304
+ SparseEmitStrategy strategy) {
1307
1305
auto stl = std::make_unique<DenseLevel>(tid, lvl, sz, /* encoded=*/ false );
1308
1306
auto it = std::make_unique<TrivialIterator>(*stl);
1307
+ it->setSparseEmitStrategy (strategy);
1309
1308
return std::make_pair (std::move (stl), std::move (it));
1310
1309
}
1311
1310
1312
1311
std::unique_ptr<SparseIterator>
1313
- sparse_tensor::makeSimpleIterator (const SparseTensorLevel &stl) {
1312
+ sparse_tensor::makeSimpleIterator (const SparseTensorLevel &stl,
1313
+ SparseEmitStrategy strategy) {
1314
+ std::unique_ptr<SparseIterator> ret;
1314
1315
if (!isUniqueLT (stl.getLT ())) {
1315
1316
// We always dedupliate the non-unique level, but we should optimize it away
1316
1317
// if possible.
1317
- return std::make_unique<DedupIterator>(stl);
1318
+ ret = std::make_unique<DedupIterator>(stl);
1319
+ } else {
1320
+ ret = std::make_unique<TrivialIterator>(stl);
1318
1321
}
1319
- return std::make_unique<TrivialIterator>(stl);
1322
+ ret->setSparseEmitStrategy (strategy);
1323
+ return ret;
1320
1324
}
1321
1325
1322
1326
std::unique_ptr<SparseIterator>
1323
1327
sparse_tensor::makeSlicedLevelIterator (std::unique_ptr<SparseIterator> &&sit,
1324
- Value offset, Value stride, Value size) {
1328
+ Value offset, Value stride, Value size,
1329
+ SparseEmitStrategy strategy) {
1325
1330
1326
- return std::make_unique<FilterIterator>(std::move (sit), offset, stride, size);
1331
+ auto ret =
1332
+ std::make_unique<FilterIterator>(std::move (sit), offset, stride, size);
1333
+ ret->setSparseEmitStrategy (strategy);
1334
+ return ret;
1327
1335
}
1328
1336
1329
1337
static const SparseIterator *tryUnwrapFilter (const SparseIterator *it) {
@@ -1335,38 +1343,42 @@ static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
1335
1343
1336
1344
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator (
1337
1345
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1338
- std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
1346
+ std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1347
+ SparseEmitStrategy strategy) {
1339
1348
1340
1349
// Try unwrap the NonEmptySubSectIterator from a filter parent.
1341
1350
parent = tryUnwrapFilter (parent);
1342
- auto it = std::make_unique<NonEmptySubSectIterator>(
1343
- b, l, parent, std::move (delegate), size);
1351
+ std::unique_ptr<SparseIterator> it =
1352
+ std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1353
+ std::move (delegate), size);
1344
1354
1345
1355
if (stride != 1 ) {
1346
1356
// TODO: We can safely skip bound checking on sparse levels, but for dense
1347
1357
// iteration space, we need the bound to infer the dense loop range.
1348
- return std::make_unique<FilterIterator>(std::move (it), /* offset=*/ C_IDX (0 ),
1349
- C_IDX (stride), /* size=*/ loopBound);
1358
+ it = std::make_unique<FilterIterator>(std::move (it), /* offset=*/ C_IDX (0 ),
1359
+ C_IDX (stride), /* size=*/ loopBound);
1350
1360
}
1361
+ it->setSparseEmitStrategy (strategy);
1351
1362
return it;
1352
1363
}
1353
1364
1354
1365
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator (
1355
1366
OpBuilder &b, Location l, const SparseIterator &subSectIter,
1356
1367
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1357
- Value loopBound, unsigned stride) {
1368
+ Value loopBound, unsigned stride, SparseEmitStrategy strategy ) {
1358
1369
1359
1370
// This must be a subsection iterator or a filtered subsection iterator.
1360
1371
auto &subSect =
1361
1372
llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter (&subSectIter));
1362
1373
1363
- auto it = std::make_unique<SubSectIterator>(
1374
+ std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
1364
1375
subSect, *tryUnwrapFilter (&parent), std::move (wrap));
1365
1376
1366
1377
if (stride != 1 ) {
1367
- return std::make_unique<FilterIterator>(std::move (it), /* offset=*/ C_IDX (0 ),
1368
- C_IDX (stride), /* size=*/ loopBound);
1378
+ it = std::make_unique<FilterIterator>(std::move (it), /* offset=*/ C_IDX (0 ),
1379
+ C_IDX (stride), /* size=*/ loopBound);
1369
1380
}
1381
+ it->setSparseEmitStrategy (strategy);
1370
1382
return it;
1371
1383
}
1372
1384
0 commit comments