@@ -219,9 +219,12 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
     OpBuilder &builder, Location loc, size_t tid, size_t dim,
     MutableArrayRef<Value> reduc, bool isParallel, ArrayRef<size_t> extraTids,
     ArrayRef<size_t> extraDims) {
+
   assert(dimTypes[tid].size() > dim);
   // We can not re-enter the same level.
   assert(!coord[tid][dim]);
+  // TODO: support multiple return on parallel for?
+  assert(!isParallel || reduc.size() <= 1);

   Value step = constantIndex(builder, loc, 1);
   auto dimType = dimTypes[tid][dim];
@@ -232,11 +235,38 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(
   Value lo = isSparseInput ? pidxs[tid][dim]      // current offset
                            : loopSeqStack.back(); // universal tid
   Value hi = highs[tid][dim];
+  Operation *loop = nullptr;
+  Value iv;
+  if (isParallel) {
+    scf::ParallelOp parOp =
+        builder.create<scf::ParallelOp>(loc, lo, hi, step, reduc);
+    builder.setInsertionPointToStart(parOp.getBody());
+    assert(parOp.getNumReductions() == reduc.size());
+    iv = parOp.getInductionVars()[0];
+
+    // In-place update on the reduction variable vector.
+    // Note that the init vals are not the actual reduction variables but are
+    // instead used as a `special handle` to (temporarily) represent them. The
+    // expression on the init vals will be moved into scf.reduce and replaced
+    // with the block arguments when exiting the loop (see exitForLoop). This
+    // is needed as we can not build the actual reduction block and get the
+    // actual reduction variable before users fill the parallel loop body.
+    for (int i = 0, e = reduc.size(); i < e; i++)
+      reduc[i] = parOp.getInitVals()[i];
+    loop = parOp;
+  } else {
+    scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
+    builder.setInsertionPointToStart(forOp.getBody());
+    iv = forOp.getInductionVar();
+
+    // In-place update on the reduction variable vector.
+    assert(forOp.getNumRegionIterArgs() == reduc.size());
+    for (int i = 0, e = reduc.size(); i < e; i++)
+      reduc[i] = forOp.getRegionIterArg(i);
+    loop = forOp;
+  }
+  assert(loop && iv);

-  scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, reduc);
-  builder.setInsertionPointToStart(forOp.getBody());
-  Value iv = forOp.getInductionVar();
-  assert(iv);
   if (isSparseInput) {
     pidxs[tid][dim] = iv;
     // Generating a load on the indices array yields the coordinate.
@@ -253,16 +283,12 @@ Operation *SparseTensorLoopEmitter::enterLoopOverTensorAtDim(

   // NOTE: we can also prepare for the next dim here in advance
   // Push the loop into the stack.
-  loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), forOp,
+  loopStack.emplace_back(ArrayRef<size_t>(tid), ArrayRef<size_t>(dim), loop,
                          coord[tid][dim]);
   // Emit extra locals.
   emitExtraLocalsForTensorsAtDenseDims(builder, loc, extraTids, extraDims);

-  // In-place update on the reduction variable vector.
-  assert(forOp.getNumRegionIterArgs() == reduc.size());
-  for (int i = 0, e = reduc.size(); i < e; i++)
-    reduc[i] = forOp.getRegionIterArg(i);
-  return forOp;
+  return loop;
 }

 Operation *SparseTensorLoopEmitter::enterCoIterationOverTensorsAtDims(
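
To make the `special handle` comment above concrete: for a parallel reduction the emitter now creates a plain `scf.parallel` with an `init` clause, and the reduction block is only materialized later by `exitForLoop`. The hand-written MLIR sketch below (a float sum over a hypothetical dense `%buffer`; the function name and operands are illustrative, not output of this patch) shows the final shape such a loop takes:

```mlir
func.func @parallel_sum(%buffer: memref<100xf32>) -> f32 {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c100 = arith.constant 100 : index
  %zero = arith.constant 0.0 : f32
  // `init (%zero)` is the value that enterLoopOverTensorAtDim hands back
  // through reduc[0] as a placeholder; users build the reduction expression
  // against it, and exitForLoop later rewrites that expression into the
  // scf.reduce block below.
  %sum = scf.parallel (%i) = (%c0) to (%c100) step (%c1) init (%zero) -> f32 {
    %elem = memref.load %buffer[%i] : memref<100xf32>
    scf.reduce(%elem) : f32 {
    ^bb0(%lhs: f32, %rhs: f32):
      %res = arith.addf %lhs, %rhs : f32
      scf.reduce.return %res : f32
    }
  }
  return %sum : f32
}
```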
@@ -434,17 +460,73 @@ void SparseTensorLoopEmitter::emitExtraLocalsForTensorsAtDenseDims(
   }
 }

-SmallVector<Value, 2>
-SparseTensorLoopEmitter::exitForLoop(OpBuilder &builder, Location loc,
-                                     ArrayRef<Value> reduc) {
+void SparseTensorLoopEmitter::exitForLoop(RewriterBase &rewriter, Location loc,
+                                          MutableArrayRef<Value> reduc) {
   LoopLevelInfo &loopInfo = loopStack.back();
   auto &dims = loopStack.back().dims;
   auto &tids = loopStack.back().tids;
-  auto forOp = llvm::cast<scf::ForOp>(loopInfo.loop);
-  if (!reduc.empty()) {
-    assert(reduc.size() == forOp.getNumResults());
-    builder.setInsertionPointToEnd(forOp.getBody());
-    builder.create<scf::YieldOp>(loc, reduc);
+  auto forOp = llvm::dyn_cast<scf::ForOp>(loopInfo.loop);
+  if (forOp) {
+    if (!reduc.empty()) {
+      assert(reduc.size() == forOp.getNumResults());
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      rewriter.create<scf::YieldOp>(loc, reduc);
+    }
+    // Exit the loop.
+    rewriter.setInsertionPointAfter(forOp);
+    // In-place update reduction variables.
+    for (unsigned i = 0, e = forOp.getResults().size(); i < e; i++)
+      reduc[i] = forOp.getResult(i);
+  } else {
+    auto parOp = llvm::cast<scf::ParallelOp>(loopInfo.loop);
+    if (!reduc.empty()) {
+      assert(reduc.size() == parOp.getInitVals().size() && reduc.size() == 1);
+      Operation *redExp = reduc.front().getDefiningOp();
+      // The reduction expression should have no uses at this point.
+      assert(redExp->getUses().empty());
+      // This must be a binary operation.
+      // NOTE: It is the users' responsibility to ensure that the operation is
+      // commutative.
+      assert(redExp->getNumOperands() == 2 && redExp->getNumResults() == 1);
+
+      Value redVal = parOp.getInitVals().front();
+      Value curVal;
+      if (redExp->getOperand(0) == redVal)
+        curVal = redExp->getOperand(1);
+      else if (redExp->getOperand(1) == redVal)
+        curVal = redExp->getOperand(0);
+      // One of the operands must be the init value (which is also the
+      // previous reduction value).
+      assert(curVal);
+      // The reduction expression should be the only user of the reduction
+      // value inside the parallel for.
+      unsigned numUsers = 0;
+      for (Operation *op : redVal.getUsers()) {
+        if (op->getParentOp() == parOp)
+          numUsers++;
+      }
+      assert(numUsers == 1);
+      (void)numUsers; // to silence unused variable warning in release build
+
+      rewriter.setInsertionPointAfter(redExp);
+      auto redOp = rewriter.create<scf::ReduceOp>(loc, curVal);
+      // Attach the reduction expression to the reduction op.
+      Block *redBlock = &redOp.getRegion().getBlocks().front();
+      rewriter.setInsertionPointToEnd(redBlock);
+      Operation *newRed = rewriter.clone(*redExp);
+      // Replace the arguments of the reduction expression with the block
+      // arguments from scf.reduce.
+      rewriter.updateRootInPlace(
+          newRed, [&]() { newRed->setOperands(redBlock->getArguments()); });
+      // Erase the now-outdated reduction expression.
+      rewriter.eraseOp(redExp);
+      rewriter.setInsertionPointToEnd(redBlock);
+      rewriter.create<scf::ReduceReturnOp>(loc, newRed->getResult(0));
+    }
+    rewriter.setInsertionPointAfter(parOp);
+    // In-place update reduction variables.
+    for (unsigned i = 0, e = parOp.getResults().size(); i < e; i++)
+      reduc[i] = parOp.getResult(i);
   }

   // Finished iterating a tensor, clean up
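
To illustrate the rewrite performed above, assume the single supported reduction is a float sum. Right before `exitForLoop` runs, the loop body contains a user-built expression whose operand is the init-value handle; afterwards that expression has been cloned into an `scf.reduce` and erased. A hand-written before/after fragment (`%zero` is the `scf.parallel` init value and `%elem` the per-iteration contribution; both assumed defined in context):

```mlir
// Before exitForLoop: %zero (the init value) stands in for the running
// reduction, and %red has no users outside this expression.
%red = arith.addf %zero, %elem : f32

// After exitForLoop: the expression is cloned into an scf.reduce, its
// operands rewired to the block arguments, and the original addf erased.
scf.reduce(%elem) : f32 {
^bb0(%lhs: f32, %rhs: f32):
  %res = arith.addf %lhs, %rhs : f32
  scf.reduce.return %res : f32
}
```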
@@ -458,14 +540,10 @@ SparseTensorLoopEmitter::exitForLoop(OpBuilder &builder, Location loc,
     if (!isDenseDLT(dimTypes[tid][dim]))
       highs[tid][dim] = Value();
   }
-  // exit the loop
-  builder.setInsertionPointAfter(forOp);
-  return forOp.getResults();
 }

-SmallVector<Value, 2>
-SparseTensorLoopEmitter::exitCoiterationLoop(OpBuilder &builder, Location loc,
-                                             ArrayRef<Value> reduc) {
+void SparseTensorLoopEmitter::exitCoIterationLoop(
+    OpBuilder &builder, Location loc, MutableArrayRef<Value> reduc) {
   auto whileOp = llvm::cast<scf::WhileOp>(loopStack.back().loop);
   auto &dims = loopStack.back().dims;
   auto &tids = loopStack.back().tids;
@@ -499,10 +577,10 @@ SparseTensorLoopEmitter::exitCoiterationLoop(OpBuilder &builder, Location loc,
   }

   // Reduction value from users.
-  SmallVector<Value, 2> ret;
-  for (auto red : reduc) {
-    operands.push_back(red);
-    ret.push_back(whileOp->getResult(o++));
+  for (unsigned i = 0, e = reduc.size(); i < e; i++) {
+    operands.push_back(reduc[i]);
+    // In-place update on the reduction variable.
+    reduc[i] = whileOp->getResult(o++);
   }

   // An (optional) universal index.
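
Unlike the parallel case, reductions in the co-iteration loop need no special rewriting: they travel through the ordinary `scf.while` iteration arguments, entering via the init list and leaving via `scf.condition`/`scf.yield`, which is why `exitCoIterationLoop` only has to thread `reduc` through `operands` and read the results back. A hand-written, self-contained sketch of that pattern (a simple bounded sum with a hypothetical `@coiterate_sum`, not real sparse co-iteration):

```mlir
func.func @coiterate_sum(%values: memref<?xf32>, %lo: index, %hi: index) -> f32 {
  %c1 = arith.constant 1 : index
  %zero = arith.constant 0.0 : f32
  // %acc enters through the init list, is forwarded by scf.condition, and is
  // updated via scf.yield; the caller reads it back from the loop results.
  %res:2 = scf.while (%pos = %lo, %acc = %zero) : (index, f32) -> (index, f32) {
    %inbound = arith.cmpi ult, %pos, %hi : index
    scf.condition(%inbound) %pos, %acc : index, f32
  } do {
  ^bb0(%pos: index, %acc: f32):
    %v = memref.load %values[%pos] : memref<?xf32>
    %sum = arith.addf %acc, %v : f32
    %next = arith.addi %pos, %c1 : index
    scf.yield %next, %sum : index, f32
  }
  return %res#1 : f32
}
```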
@@ -517,26 +595,24 @@ SparseTensorLoopEmitter::exitCoiterationLoop(OpBuilder &builder, Location loc,
   assert(o == operands.size());
   builder.create<scf::YieldOp>(loc, operands);
   builder.setInsertionPointAfter(whileOp);
-  return ret;
 }

-SmallVector<Value, 2>
-SparseTensorLoopEmitter::exitCurrentLoop(OpBuilder &builder, Location loc,
-                                         ArrayRef<Value> reduc) {
+void SparseTensorLoopEmitter::exitCurrentLoop(RewriterBase &rewriter,
+                                              Location loc,
+                                              MutableArrayRef<Value> reduc) {
   // Clean up the values; this helps us to discover potential bugs at an
   // earlier stage (instead of silently using a wrong value).
   LoopLevelInfo &loopInfo = loopStack.back();
   assert(loopInfo.tids.size() == loopInfo.dims.size());
   SmallVector<Value, 2> red;
   if (llvm::isa<scf::WhileOp>(loopInfo.loop)) {
-    red = exitCoiterationLoop(builder, loc, reduc);
+    exitCoIterationLoop(rewriter, loc, reduc);
   } else {
-    red = exitForLoop(builder, loc, reduc);
+    exitForLoop(rewriter, loc, reduc);
   }

   assert(loopStack.size() == loopSeqStack.size());
   loopStack.pop_back();
-  return red;
 }

 //===----------------------------------------------------------------------===//