Skip to content

Commit 8383897

Browse files
author
Peiming Liu
committed
[mlir][sparse] support Parallel for/reduction.
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135927
1 parent 03c7cd3 commit 8383897

File tree

7 files changed

+285
-127
lines changed

7 files changed

+285
-127
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 111 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,12 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
219219
OpBuilder &builder, Location loc, size_t tid, size_t dim,
220220
MutableArrayRef<Value> reduc, bool isParallel, ArrayRef<size_t> extraTids,
221221
ArrayRef<size_t> extraDims) {
222+
222223
assert(dimTypes[tid].size() > dim);
223224
// We can not re-enter the same level.
224225
assert(!coord[tid][dim]);
226+
// TODO: support multiple return on parallel for?
227+
assert(!isParallel || reduc.size() <= 1);
225228

226229
Value step = constantIndex(builder, loc, 1);
227230
auto dimType = dimTypes[tid][dim];
@@ -232,11 +235,38 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
232235
Value lo = isSparseInput ? pidxs[tid][dim] // current offset
233236
: loopSeqStack.back(); // universal tid
234237
Value hi = highs[tid][dim];
238+
Operation *loop = nullptr;
239+
Value iv;
240+
if (isParallel) {
241+
scf::ParallelOp parOp =
242+
builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
243+
builder.setInsertionPointToStart(parOp.getBody());
244+
assert(parOp.getNumReductions() == reduc.size());
245+
iv = parOp.getInductionVars()[0];
246+
247+
// In-place update on the reduction variable vector.
248+
// Note that the init vals are not the actual reduction variables but instead
249+
// used as a `special handle` to (temporarily) represent them. The
250+
// expression on init vals will be moved into scf.reduce and replaced with
251+
// the block arguments when exiting the loop (see exitForLoop). This is
252+
// needed as we can not build the actual reduction block and get the actual
253+
// reduction variable before users fill parallel loop body.
254+
for (int i = 0, e = reduc.size(); i < e; i++)
255+
reduc[i] = parOp.getInitVals()[i];
256+
loop = parOp;
257+
} else {
258+
scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
259+
builder.setInsertionPointToStart(forOp.getBody());
260+
iv = forOp.getInductionVar();
261+
262+
// In-place update on the reduction variable vector.
263+
assert(forOp.getNumRegionIterArgs() == reduc.size());
264+
for (int i = 0, e = reduc.size(); i < e; i++)
265+
reduc[i] = forOp.getRegionIterArg(i);
266+
loop = forOp;
267+
}
268+
assert(loop && iv);
235269

236-
scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
237-
builder.setInsertionPointToStart(forOp.getBody());
238-
Value iv = forOp.getInductionVar();
239-
assert(iv);
240270
if (isSparseInput) {
241271
pidxs[tid][dim] = iv;
242272
// Generating a load on the indices array yields the coordinate.
@@ -253,16 +283,12 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
253283

254284
// NOTE: we can also prepare for next dim here in advance
255285
// Push the loop into stack
256-
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
286+
loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
257287
coord[tid][dim]);
258288
// Emit extra locals.
259289
emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims);
260290

261-
// In-place update on the reduction variable vector.
262-
assert(forOp.getNumRegionIterArgs() == reduc.size());
263-
for (int i = 0, e = reduc.size(); i < e; i++)
264-
reduc[i] = forOp.getRegionIterArg(i);
265-
return forOp;
291+
return loop;
266292
}
267293

268294
Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
@@ -434,17 +460,73 @@ void SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDenseDims(
434460
}
435461
}
436462

437-
SmallVector<Value, 2>
438-
SparseTensorLoopEmitter::exitForLoop(OpBuilder &builder, Location loc,
439-
ArrayRef<Value> reduc) {
463+
void SparseTensorLoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
464+
MutableArrayRef<Value> reduc) {
440465
LoopLevelInfo &loopInfo = loopStack.back();
441466
auto &dims = loopStack.back().dims;
442467
auto &tids = loopStack.back().tids;
443-
auto forOp = llvm::cast<scf::ForOp>(loopInfo.loop);
444-
if (!reduc.empty()) {
445-
assert(reduc.size() == forOp.getNumResults());
446-
builder.setInsertionPointToEnd(forOp.getBody());
447-
builder.create<scf::YieldOp>(loc, reduc);
468+
auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
469+
if (forOp) {
470+
if (!reduc.empty()) {
471+
assert(reduc.size() == forOp.getNumResults());
472+
rewriter.setInsertionPointToEnd(forOp.getBody());
473+
rewriter.create<scf::YieldOp>(loc, reduc);
474+
}
475+
// Exit the loop.
476+
rewriter.setInsertionPointAfter(forOp);
477+
// In-place update reduction variables.
478+
for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
479+
reduc[i] = forOp.getResult(i);
480+
} else {
481+
auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
482+
if (!reduc.empty()) {
483+
assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
484+
Operation *redExp = reduc.front().getDefiningOp();
485+
// Reduction expression should have no use.
486+
assert(redExp->getUses().empty());
487+
// This must be a binary operation.
488+
// NOTE: It is the users' responsibility to ensure the operation is
489+
// commutative.
490+
assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);
491+
492+
Value redVal = parOp.getInitVals().front();
493+
Value curVal;
494+
if (redExp->getOperand(0) == redVal)
495+
curVal = redExp->getOperand(1);
496+
else if (redExp->getOperand(1) == redVal)
497+
curVal = redExp->getOperand(0);
498+
// One of the operands must be the init value (which is also the
499+
// previous reduction value).
500+
assert(curVal);
501+
// The reduction expression should be the only user of the reduction val
502+
// inside the parallel for.
503+
unsigned numUsers = 0;
504+
for (Operation *op : redVal.getUsers()) {
505+
if (op->getParentOp() == parOp)
506+
numUsers++;
507+
}
508+
assert(numUsers == 1);
509+
(void)numUsers; // to silence unused variable warning in release build
510+
511+
rewriter.setInsertionPointAfter(redExp);
512+
auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
513+
// Attach to the reduction op.
514+
Block *redBlock = &redOp.getRegion().getBlocks().front();
515+
rewriter.setInsertionPointToEnd(redBlock);
516+
Operation *newRed = rewriter.clone(*redExp);
517+
// Replaces arguments of the reduction expression by using the block
518+
// arguments from scf.reduce.
519+
rewriter.updateRootInPlace(
520+
newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
521+
// Erases the out-dated reduction expression.
522+
rewriter.eraseOp(redExp);
523+
rewriter.setInsertionPointToEnd(redBlock);
524+
rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
525+
}
526+
rewriter.setInsertionPointAfter(parOp);
527+
// In-place update reduction variables.
528+
for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
529+
reduc[i] = parOp.getResult(i);
448530
}
449531

450532
// Finished iterating a tensor, clean up
@@ -458,14 +540,10 @@ SparseTensorLoopEmitter::exitForLoop(OpBuilder &builder, Location loc,
458540
if (!isDenseDLT(dimTypes[tid][dim]))
459541
highs[tid][dim] = Value();
460542
}
461-
// exit the loop
462-
builder.setInsertionPointAfter(forOp);
463-
return forOp.getResults();
464543
}
465544

466-
SmallVector<Value, 2>
467-
SparseTensorLoopEmitter::exitCoiterationLoop(OpBuilder &builder, Location loc,
468-
ArrayRef<Value> reduc) {
545+
void SparseTensorLoopEmitter::exitCoIterationLoop(
546+
OpBuilder &builder, Location loc, MutableArrayRef<Value> reduc) {
469547
auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
470548
auto &dims = loopStack.back().dims;
471549
auto &tids = loopStack.back().tids;
@@ -499,10 +577,10 @@ SparseTensorLoopEmitter::exitCoiterationLoop(OpBuilder &builder, Location loc,
499577
}
500578

501579
// Reduction value from users.
502-
SmallVector<Value, 2> ret;
503-
for (auto red : reduc) {
504-
operands.push_back(red);
505-
ret.push_back(whileOp->getResult(o++));
580+
for (unsigned i = 0, e = reduc.size(); i < e; i++) {
581+
operands.push_back(reduc[i]);
582+
// In place update reduction variable.
583+
reduc[i] = whileOp->getResult(o++);
506584
}
507585

508586
// An (optional) universal index.
@@ -517,26 +595,24 @@ SparseTensorLoopEmitter::exitCoiterationLoop(OpBuilder &builder, Location loc,
517595
assert(o == operands.size());
518596
builder.create<scf::YieldOp>(loc, operands);
519597
builder.setInsertionPointAfter(whileOp);
520-
return ret;
521598
}
522599

523-
SmallVector<Value, 2>
524-
SparseTensorLoopEmitter::exitCurrentLoop(OpBuilder &builder, Location loc,
525-
ArrayRef<Value> reduc) {
600+
void SparseTensorLoopEmitter::exitCurrentLoop(RewriterBase &rewriter,
601+
Location loc,
602+
MutableArrayRef<Value> reduc) {
526603
// Clean up the values, it would help us to discover potential bugs at an
527604
// earlier stage (instead of silently using a wrong value).
528605
LoopLevelInfo &loopInfo = loopStack.back();
529606
assert(loopInfo.tids.size() == loopInfo.dims.size());
530607
SmallVector<Value, 2> red;
531608
if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
532-
red = exitCoiterationLoop(builder, loc, reduc);
609+
exitCoIterationLoop(rewriter, loc, reduc);
533610
} else {
534-
red = exitForLoop(builder, loc, reduc);
611+
exitForLoop(rewriter, loc, reduc);
535612
}
536613

537614
assert(loopStack.size() == loopSeqStack.size());
538615
loopStack.pop_back();
539-
return red;
540616
}
541617

542618
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ class SparseTensorLoopEmitter {
380380
ArrayRef<size_t> dims, bool needsUniv, MutableArrayRef<Value> reduc = {},
381381
ArrayRef<size_t> extraTids = {}, ArrayRef<size_t> extraDims = {});
382382

383-
SmallVector<Value, 2> exitCurrentLoop(OpBuilder &builder, Location loc,
384-
ArrayRef<Value> reduc = {});
383+
void exitCurrentLoop(RewriterBase &rewriter, Location loc,
384+
MutableArrayRef<Value> reduc = {});
385385

386386
/// Returns the array of coordinate for all the loop generated till now.
387387
void getCoordinateArray(SmallVectorImpl<Value> &coords) const {
@@ -452,17 +452,35 @@ class SparseTensorLoopEmitter {
452452
ArrayRef<size_t> dims);
453453

454454
/// Exits a for loop, returns the reduction results, e.g.,
455+
/// For sequential for loops:
455456
/// %ret = for () {
456457
/// ...
458+
/// %val = addi %args, %c
457459
/// yield %val
458460
/// }
459-
/// Return %ret to user, while %val is provided by users (`reduc`)
460-
SmallVector<Value, 2> exitForLoop(OpBuilder &builder, Location loc,
461-
ArrayRef<Value> reduc);
461+
/// For parallel loops, the following code generated by users:
462+
/// %ret = parallel () init(%args) {
463+
/// ...
464+
/// %val = op %args, %c
465+
/// }
466+
/// will be transformed into
467+
/// %ret = parallel () init(%args) {
468+
/// ...
469+
/// scf.reduce(%c) bb0(%0, %1){
470+
/// %val = op %0, %1
471+
/// scf.reduce.return %val
472+
/// }
473+
/// }
474+
/// NOTE: only one instruction will be moved into reduce block, transformation
475+
/// will fail if multiple instructions are used to compute the reduction
476+
/// value.
477+
/// Return %ret to user, while %val is provided by users (`reduc`).
478+
void exitForLoop(RewriterBase &rewriter, Location loc,
479+
MutableArrayRef<Value> reduc);
462480

463481
/// Exits a while loop, returns the reduction results.
464-
SmallVector<Value, 2> exitCoiterationLoop(OpBuilder &builder, Location loc,
465-
ArrayRef<Value> reduc);
482+
void exitCoIterationLoop(OpBuilder &builder, Location loc,
483+
MutableArrayRef<Value> reduc);
466484

467485
// Whether the loop emitter needs to treat the last tensor as the output
468486
// tensor.

0 commit comments

Comments
 (0)