
Commit 1328bb6

Author: Peiming Liu
[mlir][sparse] extend loop emitter and optimize lattices with the awareness of slice based iteration
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D142929
1 parent: 8d024a7

7 files changed: +201 -124 lines

mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h

Lines changed: 6 additions & 0 deletions
@@ -399,11 +399,17 @@ class Merger {
   /// to sparse level-type.
   bool hasAnySparse(const BitVector &bits) const;
 
+  /// Returns true if bits contains a dependent index reduction condition on
+  /// sparse levels.
+  bool hasSparseIdxReduction(const BitVector &bits) const;
+
   /// Gets the level-type of the `t`th tensor on the `i`th loop.
   DimLevelType getDimLevelType(TensorId t, LoopId i) const {
     assert(t < numTensors && i < numLoops);
     return lvlTypes[t][i];
   }
+
+  /// Gets the level-type of the TensorLoopId.
   DimLevelType getDimLevelType(TensorLoopId b) const {
     return getDimLevelType(tensor(b), loop(b));
   }
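For context, `hasAnySparse` returns true when any set bit in a lattice point corresponds to a sparse level; the new `hasSparseIdxReduction` asks the analogous question for levels whose coordinates are defined by a non-trivial affine expression. A minimal sketch of how such a predicate could be implemented, assuming a hypothetical helper `isLvlWithNonTrivialIdxExp(b)` that tests whether `TensorLoopId` `b` carries a dependent affine index (the helper is an assumption for illustration, not part of this diff):

// Sketch only (not from this commit); mirrors the shape of hasAnySparse.
bool Merger::hasSparseIdxReduction(const BitVector &bits) const {
  for (TensorLoopId b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isLvlWithNonTrivialIdxExp(b)) // hypothetical helper
      return true;
  return false;
}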

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.cpp

Lines changed: 51 additions & 7 deletions
@@ -28,6 +28,23 @@ static bool isMaterializing(Value val) {
          val.getDefiningOp<bufferization::AllocTensorOp>();
 }
 
+/// Sorts the `target` array's elements according to the `order` array.
+static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
+                                  ArrayRef<LoopId> order) {
+  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
+    assert(l != r);
+    int idxL = -1, idxR = -1;
+    for (int i = 0, e = order.size(); i < e; i++) {
+      if (order[i] == l)
+        idxL = i;
+      if (order[i] == r)
+        idxR = i;
+    }
+    assert(idxL >= 0 && idxR >= 0);
+    return idxL < idxR;
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // Code generation environment constructor and general methods
 //===----------------------------------------------------------------------===//
@@ -57,15 +74,42 @@ void CodegenEnv::startEmit() {
     insChain = sparseOut->get();
     latticeMerger.setHasSparseOut(true);
   }
+
+  // Sort each related loop array so the loops appear in the same order as
+  // they do in the topoOrder.
+  // TODO: since we only handle affine addition for slice-based codegen, and
+  // addition is associative, the order in which we evaluate the expression
+  // does not matter. However, to support multiplication, the order of the
+  // loop indices should match the evaluation order of the affine AST.
+
   // Initialize loop emitter.
-  SmallVector<Value> tensors;
-  for (OpOperand &t : linalgOp->getOpOperands())
+  SmallVector<Value> tensors; // input tensors passed to loop emitter
+  for (OpOperand &t : linalgOp->getOpOperands()) {
     tensors.push_back(t.get());
-  loopEmitter.initialize(tensors,
-                         StringAttr::get(linalgOp.getContext(),
-                                         linalg::GenericOp::getOperationName()),
-                         /*hasOutput=*/true,
-                         /*isSparseOut=*/sparseOut != nullptr, topSort);
+    Level rank = linalgOp.getMatchingIndexingMap(&t).getNumResults();
+    for (Level lvl = 0; lvl < rank; lvl++) {
+      sortArrayBasedOnOrder(
+          latticeMerger.getDependentLoops(t.getOperandNumber(), lvl), topSort);
+    }
+  }
+
+  loopEmitter.initialize(
+      tensors,
+      StringAttr::get(linalgOp.getContext(),
+                      linalg::GenericOp::getOperationName()),
+      /*hasOutput=*/true,
+      /*isSparseOut=*/sparseOut != nullptr, topSort,
+      // TODO: compute the map and pass it to the loop emitter directly,
+      // instead of passing in a callback.
+      [this](TensorId t, Level lvl) -> std::vector<std::pair<TensorId, Level>> {
+        // Translates a list of loop indices to a list of [tid, dim] pairs.
+        std::vector<LoopId> rLoops = this->merger().getDependentLoops(t, lvl);
+        std::vector<std::pair<TensorId, Level>> ret;
+        ret.reserve(rLoops.size());
+        for (LoopId l : rLoops)
+          ret.emplace_back(this->merger().getLoopDefiningLvl(l));
+        return ret;
+      });
 }
 
 std::optional<Operation *> CodegenEnv::genLoopBoundary(
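To make the effect of `sortArrayBasedOnOrder` concrete, here is a small standalone sketch (plain std types instead of MLIR's, with illustrative loop ids) showing how the dependent loops of an affine index such as d0 + d2 are reordered to follow the topological loop order:

#include <algorithm>
#include <cassert>
#include <vector>

using LoopId = unsigned;

// Same comparator idea as the helper above: an element's rank is its
// position in `order`, so `target` ends up in `order`'s relative order.
static void sortArrayBasedOnOrder(std::vector<LoopId> &target,
                                  const std::vector<LoopId> &order) {
  std::sort(target.begin(), target.end(), [&order](LoopId l, LoopId r) {
    int idxL = -1, idxR = -1;
    for (int i = 0, e = static_cast<int>(order.size()); i < e; i++) {
      if (order[i] == l)
        idxL = i;
      if (order[i] == r)
        idxR = i;
    }
    assert(idxL >= 0 && idxR >= 0 && "both ids must appear in the order");
    return idxL < idxR;
  });
}

int main() {
  std::vector<LoopId> topSort = {2, 0, 1}; // loop 2 outermost, then 0, then 1
  std::vector<LoopId> dependent = {0, 2};  // loops used by the affine index
  sortArrayBasedOnOrder(dependent, topSort);
  assert(dependent[0] == 2 && dependent[1] == 0); // now matches topSort
}

The comparator does a linear scan of `order` per comparison; with the handful of loops in a typical kernel that is cheap, and it avoids building an inverse-permutation map.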

mlir/lib/Dialect/SparseTensor/Transforms/CodegenEnv.h

Lines changed: 0 additions & 1 deletion
@@ -99,7 +99,6 @@ class CodegenEnv {
     topSort.reserve(capacity);
   }
 
-  ArrayRef<LoopId> getTopSort() const { return topSort; };
   ArrayRef<LoopId> getTopSortSlice(LoopOrd n, LoopOrd m) const;
   ArrayRef<LoopId> getLoopStackUpTo(LoopOrd n) const;
   ArrayRef<LoopId> getCurrentLoopStack() const;

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 17 additions & 6 deletions
@@ -208,12 +208,14 @@ Value LoopEmitter::genSparseCrd(OpBuilder &builder, Location loc, TensorId tid,
 }
 
 LoopEmitter::LoopEmitter(ValueRange tensors, StringAttr loopTag, bool hasOutput,
-                         bool isSparseOut, ArrayRef<LoopId> topSort) {
-  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort);
+                         bool isSparseOut, ArrayRef<LoopId> topSort,
+                         DependentLvlGetter dimGetter) {
+  initialize(tensors, loopTag, hasOutput, isSparseOut, topSort, dimGetter);
 }
 
 void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
-                             bool isSparseOut, ArrayRef<LoopId> topSort) {
+                             bool isSparseOut, ArrayRef<LoopId> topSort,
+                             DependentLvlGetter dimGetter) {
   // First initialize the top-level type of the fields.
   this->loopTag = loopTag;
   this->hasOutput = hasOutput;
@@ -242,6 +244,9 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
   this->loopStack.reserve(numLoops);
   this->loopSeqStack.reserve(numLoops);
 
+  this->dependentLvlMap.assign(
+      numTensors, std::vector<std::vector<std::pair<TensorId, Level>>>());
+
   // Initialize nested types of `TensorId`-indexed fields.
   for (TensorId tid = 0; tid < numTensors; tid++) {
     const Value t = tensors[tid];
@@ -283,6 +288,12 @@ void LoopEmitter::initialize(ValueRange ts, StringAttr loopTag, bool hasOutput,
     coordinatesBuffers[tid].assign(lvlRank, Value());
     sliceOffsets[tid].assign(lvlRank, Value());
     sliceStrides[tid].assign(lvlRank, Value());
+
+    dependentLvlMap[tid].assign(lvlRank,
+                                std::vector<std::pair<TensorId, Level>>());
+    if (dimGetter)
+      for (Level l = 0; l < lvlRank; l++)
+        dependentLvlMap[tid][l] = dimGetter(tid, l);
   }
 
   // Construct the inverse of the `topSort` from the sparsifier.
@@ -997,8 +1008,8 @@ void LoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
   }
 }
 
-void LoopEmitter::exitCoIterationLoop(OpBuilder &builder, Location loc,
-                                      MutableArrayRef<Value> reduc) {
+void LoopEmitter::exitWhileLoop(OpBuilder &builder, Location loc,
+                                MutableArrayRef<Value> reduc) {
   const LoopInfo &loopInfo = loopStack.back();
   auto whileOp = llvm::cast<scf::WhileOp>(loopInfo.loop);
   builder.setInsertionPointToEnd(loopInfo.userCodeBlock);
@@ -1082,7 +1093,7 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
   assert(loopInfo.tids.size() == loopInfo.lvls.size());
   SmallVector<Value> red;
   if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
-    exitCoIterationLoop(rewriter, loc, reduc);
+    exitWhileLoop(rewriter, loc, reduc);
   } else {
     exitForLoop(rewriter, loc, reduc);
   }
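For intuition about the `dependentLvlMap` layout being initialized above, here is a standalone sketch (plain std types; the tensor and level ids are illustrative, not taken from the commit) of how the nested vectors are populated and queried:

#include <cassert>
#include <utility>
#include <vector>

using TensorId = unsigned;
using Level = unsigned;
using TidLvl = std::pair<TensorId, Level>;

int main() {
  // dependentLvlMap[tid][lvl] holds the [tid, lvl] pairs defining the loops
  // that appear in the affine index on (tid, lvl); empty means trivial index.
  std::vector<std::vector<std::vector<TidLvl>>> dependentLvlMap(2);
  dependentLvlMap[0].assign(/*lvlRank=*/1, {});
  dependentLvlMap[1].assign(/*lvlRank=*/2, {});

  // Suppose tensor 0, level 0 is indexed by d0 + d1, and d0, d1 are defined
  // by levels 0 and 1 of tensor 1 (what a DependentLvlGetter would return).
  dependentLvlMap[0][0] = {{1, 0}, {1, 1}};

  assert(!dependentLvlMap[0][0].empty()); // needs index reduction
  assert(dependentLvlMap[1][0].empty());  // plain coordinate iteration
}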

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.h

Lines changed: 23 additions & 4 deletions
@@ -76,6 +76,14 @@ class LoopEmitter {
   /// initializing the loop emitter (e.g., to fill a dense output with zeros).
   using OutputUpdater = function_ref<Value(OpBuilder &builder, Location loc,
                                            Value memref, Value tensor)>;
+  // Map from [tid, dim] to a list of dependent [tid, dim] pairs for an
+  // affine expression index on sparse tensors.
+  // E.g., the affine index (d0 + d1) depends on the two [tid, dim] pairs
+  // that define d0 and d1 (for affine expression reduction).
+  // If the list is empty, there is no affine expression on the input
+  // [tid, dim].
+  using DependentLvlGetter =
+      function_ref<std::vector<std::pair<TensorId, Level>>(TensorId, Level)>;
 
   LoopEmitter() = default;
 
@@ -89,11 +97,13 @@ class LoopEmitter {
   /// to `LoopId`.
   void initialize(ValueRange tensors, StringAttr loopTag = nullptr,
                   bool hasOutput = false, bool isSparseOut = false,
-                  ArrayRef<LoopId> topSort = {});
+                  ArrayRef<LoopId> topSort = {},
+                  DependentLvlGetter getter = nullptr);
 
   explicit LoopEmitter(ValueRange tensors, StringAttr loopTag = nullptr,
                        bool hasOutput = false, bool isSparseOut = false,
-                       ArrayRef<LoopId> topSort = {});
+                       ArrayRef<LoopId> topSort = {},
+                       DependentLvlGetter getter = nullptr);
 
   /// Starts a loop emitting session by generating all the buffers needed
   /// for iterating over the tensors.
@@ -295,8 +305,8 @@ class LoopEmitter {
                    MutableArrayRef<Value> reduc);
 
   /// Exits a while loop, returns the reduction results.
-  void exitCoIterationLoop(OpBuilder &builder, Location loc,
-                           MutableArrayRef<Value> reduc);
+  void exitWhileLoop(OpBuilder &builder, Location loc,
+                     MutableArrayRef<Value> reduc);
 
   //
   // View-based-reshape methods.
@@ -380,6 +390,15 @@ class LoopEmitter {
   std::vector<std::vector<Value>> sliceOffsets;
   std::vector<std::vector<Value>> sliceStrides;
 
+  // Map from [tid, level] to a list of dependent [tid, level] pairs.
+  // See the comments for `DependentLvlGetter`.
+  std::vector<std::vector<std::vector<std::pair<TensorId, Level>>>>
+      dependentLvlMap;
+
+  //
+  // View-based-reshape related fields and methods.
+  //
+
   /// Collapse reassociations related to a specific tensor.
   // TODO: support expand.
   std::vector<ArrayAttr> collapseReassoc;
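A design note on the getter: `function_ref` is a non-owning view of a callable, so the emitter cannot store it; instead `initialize` eagerly copies the results into `dependentLvlMap`, and the lambda in CodegenEnv.cpp only has to outlive that single call. A shape-only sketch of a conforming callback, using standalone std types with `std::function` standing in for `llvm::function_ref`:

#include <functional>
#include <utility>
#include <vector>

using TensorId = unsigned;
using Level = unsigned;
using TidLvl = std::pair<TensorId, Level>;
// std::function stands in for llvm::function_ref in this sketch.
using DependentLvlGetter = std::function<std::vector<TidLvl>(TensorId, Level)>;

int main() {
  // A getter reporting that tensor 0, level 0 depends on tensor 1's levels
  // 0 and 1 (e.g., an affine index d0 + d1); everything else is trivial.
  DependentLvlGetter getter = [](TensorId t, Level l) -> std::vector<TidLvl> {
    if (t == 0 && l == 0)
      return {{1, 0}, {1, 1}};
    return {};
  };
  auto deps = getter(0, 0); // two dependent [tid, level] pairs
  (void)deps;
}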
