Skip to content

Commit 35fae04

Browse files
authored
[mlir][sparse] using non-static field to avoid data races. (#81165)
1 parent 17f0680 commit 35fae04

File tree

4 files changed

+50
-34
lines changed

4 files changed

+50
-34
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
9494
this->loopTag = loopTag;
9595
this->hasOutput = hasOutput;
9696
this->isSparseOut = isSparseOut;
97-
SparseIterator::setSparseEmitStrategy(emitStrategy);
97+
this->emitStrategy = emitStrategy;
9898

9999
const unsigned numManifestTensors = ts.size();
100100
const unsigned synTensorId = numManifestTensors;
@@ -166,13 +166,13 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
166166
std::unique_ptr<SparseIterator>
167167
LoopEmitter::makeLevelIterator(OpBuilder &builder, Location loc, TensorId t,
168168
Level l) {
169-
auto it = makeSimpleIterator(*lvls[t][l]);
169+
auto it = makeSimpleIterator(*lvls[t][l], emitStrategy);
170170
auto stt = getSparseTensorType(tensors[t]);
171171
if (stt.hasEncoding() && stt.getEncoding().isSlice()) {
172172
Value offset = genSliceOffset(builder, loc, tensors[t], l);
173173
Value stride = genSliceStride(builder, loc, tensors[t], l);
174-
auto slicedIt = makeSlicedLevelIterator(std::move(it), offset, stride,
175-
lvls[t][l]->getSize());
174+
auto slicedIt = makeSlicedLevelIterator(
175+
std::move(it), offset, stride, lvls[t][l]->getSize(), emitStrategy);
176176
return slicedIt;
177177
}
178178
return it;
@@ -186,7 +186,7 @@ void LoopEmitter::initializeLoopEmit(
186186
TensorId synId = getSynTensorId();
187187
for (unsigned i = 0, e = loopHighs.size(); i < e; i++) {
188188
Value sz = loopHighs[i] = synSetter(builder, loc, i);
189-
auto [stl, it] = makeSynLevelAndIterator(sz, synId, i);
189+
auto [stl, it] = makeSynLevelAndIterator(sz, synId, i, emitStrategy);
190190
lvls[synId][i] = std::move(stl);
191191
iters[synId][i].emplace_back(std::move(it));
192192
}
@@ -317,12 +317,13 @@ void LoopEmitter::initSubSectIterator(OpBuilder &builder, Location loc) {
317317
size = ADDI(size, ADDI(MULI(idxMax, C_IDX(stride)), C_IDX(1)));
318318
}
319319
it = makeNonEmptySubSectIterator(builder, loc, parent, loopHighs[loop],
320-
std::move(lvlIt), size, curDep.second);
320+
std::move(lvlIt), size, curDep.second,
321+
emitStrategy);
321322
} else {
322323
const SparseIterator &subSectIter = *iters[t][lvl].back();
323324
it = makeTraverseSubSectIterator(builder, loc, subSectIter, *parent,
324325
std::move(lvlIt), loopHighs[loop],
325-
curDep.second);
326+
curDep.second, emitStrategy);
326327
}
327328
lastIter[t] = it.get();
328329
iters[t][lvl].emplace_back(std::move(it));

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ class LoopEmitter {
380380
/// tensor.
381381
bool hasOutput;
382382
bool isSparseOut;
383+
SparseEmitStrategy emitStrategy;
383384

384385
//
385386
// Fields which have `numTensor` many entries.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -773,9 +773,6 @@ class SubSectIterator : public SparseIterator {
773773
// SparseIterator derived classes implementation.
774774
//===----------------------------------------------------------------------===//
775775

776-
SparseEmitStrategy SparseIterator::emitStrategy =
777-
SparseEmitStrategy::kFunctional;
778-
779776
void SparseIterator::genInit(OpBuilder &b, Location l,
780777
const SparseIterator *p) {
781778
if (emitStrategy == SparseEmitStrategy::kDebugInterface) {
@@ -1303,27 +1300,38 @@ sparse_tensor::makeSparseTensorLevel(OpBuilder &b, Location l, Value t,
13031300
}
13041301

13051302
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
1306-
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl) {
1303+
sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
1304+
SparseEmitStrategy strategy) {
13071305
auto stl = std::make_unique<DenseLevel>(tid, lvl, sz, /*encoded=*/false);
13081306
auto it = std::make_unique<TrivialIterator>(*stl);
1307+
it->setSparseEmitStrategy(strategy);
13091308
return std::make_pair(std::move(stl), std::move(it));
13101309
}
13111310

13121311
std::unique_ptr<SparseIterator>
1313-
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl) {
1312+
sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
1313+
SparseEmitStrategy strategy) {
1314+
std::unique_ptr<SparseIterator> ret;
13141315
if (!isUniqueLT(stl.getLT())) {
13151316
// We always dedupliate the non-unique level, but we should optimize it away
13161317
// if possible.
1317-
return std::make_unique<DedupIterator>(stl);
1318+
ret = std::make_unique<DedupIterator>(stl);
1319+
} else {
1320+
ret = std::make_unique<TrivialIterator>(stl);
13181321
}
1319-
return std::make_unique<TrivialIterator>(stl);
1322+
ret->setSparseEmitStrategy(strategy);
1323+
return ret;
13201324
}
13211325

13221326
std::unique_ptr<SparseIterator>
13231327
sparse_tensor::makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit,
1324-
Value offset, Value stride, Value size) {
1328+
Value offset, Value stride, Value size,
1329+
SparseEmitStrategy strategy) {
13251330

1326-
return std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1331+
auto ret =
1332+
std::make_unique<FilterIterator>(std::move(sit), offset, stride, size);
1333+
ret->setSparseEmitStrategy(strategy);
1334+
return ret;
13271335
}
13281336

13291337
static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
@@ -1335,38 +1343,42 @@ static const SparseIterator *tryUnwrapFilter(const SparseIterator *it) {
13351343

13361344
std::unique_ptr<SparseIterator> sparse_tensor::makeNonEmptySubSectIterator(
13371345
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
1338-
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride) {
1346+
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
1347+
SparseEmitStrategy strategy) {
13391348

13401349
// Try unwrap the NonEmptySubSectIterator from a filter parent.
13411350
parent = tryUnwrapFilter(parent);
1342-
auto it = std::make_unique<NonEmptySubSectIterator>(
1343-
b, l, parent, std::move(delegate), size);
1351+
std::unique_ptr<SparseIterator> it =
1352+
std::make_unique<NonEmptySubSectIterator>(b, l, parent,
1353+
std::move(delegate), size);
13441354

13451355
if (stride != 1) {
13461356
// TODO: We can safely skip bound checking on sparse levels, but for dense
13471357
// iteration space, we need the bound to infer the dense loop range.
1348-
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1349-
C_IDX(stride), /*size=*/loopBound);
1358+
it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1359+
C_IDX(stride), /*size=*/loopBound);
13501360
}
1361+
it->setSparseEmitStrategy(strategy);
13511362
return it;
13521363
}
13531364

13541365
std::unique_ptr<SparseIterator> sparse_tensor::makeTraverseSubSectIterator(
13551366
OpBuilder &b, Location l, const SparseIterator &subSectIter,
13561367
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
1357-
Value loopBound, unsigned stride) {
1368+
Value loopBound, unsigned stride, SparseEmitStrategy strategy) {
13581369

13591370
// This must be a subsection iterator or a filtered subsection iterator.
13601371
auto &subSect =
13611372
llvm::cast<NonEmptySubSectIterator>(*tryUnwrapFilter(&subSectIter));
13621373

1363-
auto it = std::make_unique<SubSectIterator>(
1374+
std::unique_ptr<SparseIterator> it = std::make_unique<SubSectIterator>(
13641375
subSect, *tryUnwrapFilter(&parent), std::move(wrap));
13651376

13661377
if (stride != 1) {
1367-
return std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1368-
C_IDX(stride), /*size=*/loopBound);
1378+
it = std::make_unique<FilterIterator>(std::move(it), /*offset=*/C_IDX(0),
1379+
C_IDX(stride), /*size=*/loopBound);
13691380
}
1381+
it->setSparseEmitStrategy(strategy);
13701382
return it;
13711383
}
13721384

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ class SparseIterator {
111111
public:
112112
virtual ~SparseIterator() = default;
113113

114-
static void setSparseEmitStrategy(SparseEmitStrategy strategy) {
115-
SparseIterator::emitStrategy = strategy;
114+
void setSparseEmitStrategy(SparseEmitStrategy strategy) {
115+
emitStrategy = strategy;
116116
}
117117

118118
virtual std::string getDebugInterfacePrefix() const = 0;
@@ -248,7 +248,7 @@ class SparseIterator {
248248
return ref.take_front(cursorValsCnt);
249249
}
250250

251-
static SparseEmitStrategy emitStrategy;
251+
SparseEmitStrategy emitStrategy;
252252

253253
public:
254254
const IterKind kind; // For LLVM-style RTTI.
@@ -277,32 +277,34 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
277277

278278
/// Helper function to create a simple SparseIterator object that iterate over
279279
/// the SparseTensorLevel.
280-
std::unique_ptr<SparseIterator>
281-
makeSimpleIterator(const SparseTensorLevel &stl);
280+
std::unique_ptr<SparseIterator> makeSimpleIterator(const SparseTensorLevel &stl,
281+
SparseEmitStrategy strategy);
282282

283283
/// Helper function to create a synthetic SparseIterator object that iterate
284284
/// over a dense space specified by [0,`sz`).
285285
std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
286-
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl);
286+
makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
287+
SparseEmitStrategy strategy);
287288

288289
/// Helper function to create a SparseIterator object that iterate over a
289290
/// sliced space, the orignal space (before slicing) is traversed by `sit`.
290291
std::unique_ptr<SparseIterator>
291292
makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
292-
Value stride, Value size);
293+
Value stride, Value size, SparseEmitStrategy strategy);
293294

294295
/// Helper function to create a SparseIterator object that iterate over the
295296
/// non-empty subsections set.
296297
std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
297298
OpBuilder &b, Location l, const SparseIterator *parent, Value loopBound,
298-
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
299+
std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride,
300+
SparseEmitStrategy strategy);
299301

300302
/// Helper function to create a SparseIterator object that iterate over a
301303
/// non-empty subsection created by NonEmptySubSectIterator.
302304
std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
303305
OpBuilder &b, Location l, const SparseIterator &subsectIter,
304306
const SparseIterator &parent, std::unique_ptr<SparseIterator> &&wrap,
305-
Value loopBound, unsigned stride);
307+
Value loopBound, unsigned stride, SparseEmitStrategy strategy);
306308

307309
} // namespace sparse_tensor
308310
} // namespace mlir

0 commit comments

Comments
 (0)