@@ -356,8 +356,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
356
356
RankedTensorType cooTp = getUnorderedCOOFromType (dstTp);
357
357
auto cooBuffer =
358
358
rewriter.create <AllocTensorOp>(loc, cooTp, dstDynSizes).getResult ();
359
- rewriter.create <ForeachOp>(
360
- loc, srcTensor, llvm::None ,
359
+ ForeachOp foreachOp = rewriter.create <ForeachOp>(
360
+ loc, srcTensor, cooBuffer ,
361
361
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
362
362
ValueRange reduc) {
363
363
SmallVector<Value, 4 > srcIndices;
@@ -368,11 +368,11 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
368
368
}
369
369
translateIndicesArray (builder, loc, op.getReassociationIndices (),
370
370
srcIndices, srcSizes, dstSizes, dstIndices);
371
- builder.create <InsertOp>(loc, v, cooBuffer , dstIndices);
372
- builder.create <sparse_tensor::YieldOp>(loc);
371
+ auto t = builder.create <InsertOp>(loc, v, reduc. front () , dstIndices);
372
+ builder.create <sparse_tensor::YieldOp>(loc, t );
373
373
});
374
-
375
- rewriter.replaceOpWithNewOp <ConvertOp>(op, dstTp, cooBuffer );
374
+ auto t = rewriter. create <LoadOp>(loc, foreachOp. getResult ( 0 ), true );
375
+ rewriter.replaceOpWithNewOp <ConvertOp>(op, dstTp, t );
376
376
return success ();
377
377
}
378
378
};
@@ -442,13 +442,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
442
442
rewriter.create <AllocTensorOp>(loc, cooTp, ValueRange ()).getResult ();
443
443
444
444
Value offset = constantIndex (rewriter, loc, 0 );
445
+ ForeachOp foreachOp;
445
446
for (Value input : op.getInputs ()) {
446
447
// Builds the indexing map.
447
448
448
449
// Build a for op for each input tensor to append new values into the
449
450
// output tensor.
450
- rewriter.create <ForeachOp>(
451
- loc, input, llvm::None ,
451
+ foreachOp = rewriter.create <ForeachOp>(
452
+ loc, input, cooBuffer ,
452
453
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
453
454
ValueRange reduc) {
454
455
SmallVector<Value, 4 > indices;
@@ -461,8 +462,8 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
461
462
idx = builder.create <arith::AddIOp>(loc, idx, offset);
462
463
indices.push_back (idx);
463
464
}
464
- builder.create <InsertOp>(loc, v, cooBuffer , indices);
465
- builder.create <sparse_tensor::YieldOp>(loc);
465
+ auto t = builder.create <InsertOp>(loc, v, reduc. front () , indices);
466
+ builder.create <sparse_tensor::YieldOp>(loc, t );
466
467
});
467
468
// Accumulates the offset. Note that only static-shaped inputs are allowed
468
469
// by concatenate op verifier, which saves us from computing the offset
@@ -471,7 +472,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
471
472
assert (!ShapedType::isDynamic (d));
472
473
offset = rewriter.create <arith::AddIOp>(loc, offset,
473
474
constantIndex (rewriter, loc, d));
475
+ cooBuffer = foreachOp.getResult (0 );
474
476
}
477
+
478
+ cooBuffer = rewriter.create <LoadOp>(loc, cooBuffer, true );
475
479
rewriter.replaceOpWithNewOp <ConvertOp>(op, rtp, cooBuffer);
476
480
return success ();
477
481
}
@@ -602,19 +606,19 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
602
606
srcTp = getUnorderedCOOFromType (srcTp);
603
607
tmpCoo =
604
608
rewriter.create <AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult ();
605
- rewriter.create <ForeachOp>(
606
- loc, src, llvm::None ,
609
+ auto foreachOp = rewriter.create <ForeachOp>(
610
+ loc, src, tmpCoo ,
607
611
[&](OpBuilder &builder, Location loc, ValueRange args, Value v,
608
612
ValueRange reduc) {
609
613
SmallVector<Value, 4 > indices;
610
614
for (int64_t i = 0 , e = srcTp.getRank (); i < e; i++) {
611
615
uint64_t dim = toStoredDim (encSrc, i);
612
616
indices.push_back (args[dim]);
613
617
}
614
- builder.create <InsertOp>(loc, v, tmpCoo , indices);
615
- builder.create <sparse_tensor::YieldOp>(loc);
618
+ auto t = builder.create <InsertOp>(loc, v, reduc. front () , indices);
619
+ builder.create <sparse_tensor::YieldOp>(loc, t );
616
620
});
617
- src = tmpCoo ;
621
+ src = rewriter. create <LoadOp>(loc, foreachOp. getResult ( 0 ), true ) ;
618
622
}
619
623
620
624
// Sort the COO tensor so that its elements are ordered via increasing
@@ -653,29 +657,31 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
653
657
getDynamicSizes (dstTp, srcSizes, dynDstSizes);
654
658
Value dst =
655
659
rewriter.create <AllocTensorOp>(loc, dstTp, dynDstSizes).getResult ();
656
- rewriter.create <ForeachOp>(loc, src, llvm::None,
657
- [&](OpBuilder &builder, Location loc ,
658
- ValueRange args, Value v, ValueRange reduc) {
659
- SmallVector<Value, 4 > indices;
660
- for ( int64_t i = 0 , e = srcTp. getRank (); i < e ;
661
- i++) {
662
- uint64_t dim = toStoredDim (encDst, i);
663
- indices.push_back (args[dim]);
664
- }
665
- builder.create <InsertOp>(loc, v, dst , indices);
666
- builder.create <sparse_tensor::YieldOp>(loc);
667
- });
660
+ auto foreachOp = rewriter.create <ForeachOp>(
661
+ loc, src, dst ,
662
+ [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
663
+ ValueRange reduc) {
664
+ SmallVector<Value, 4 > indices ;
665
+ for ( int64_t i = 0 , e = srcTp. getRank (); i < e; i++) {
666
+ uint64_t dim = toStoredDim (encDst, i);
667
+ indices.push_back (args[dim]);
668
+ }
669
+ auto t = builder.create <InsertOp>(loc, v, reduc. front () , indices);
670
+ builder.create <sparse_tensor::YieldOp>(loc, t );
671
+ });
668
672
669
- // Release the temporary COO if it is created.
673
+ // Release the temporary COO if it is created. Note that tmpCoo is
674
+ // invalidated due to foreach and updated to src.
670
675
if (tmpCoo)
671
- rewriter.create <DeallocTensorOp>(loc, tmpCoo );
676
+ rewriter.create <DeallocTensorOp>(loc, src );
672
677
673
678
// Directly replace op with dst results in bufferization error message
674
679
// "sparse tensor allocation should not escape function".
675
680
// As such, we insert a trivial tensor convert which will be removed by
676
681
// codegen.
677
682
rewriter.setInsertionPointAfter (op);
678
- rewriter.replaceOpWithNewOp <ConvertOp>(op, dstTp, dst);
683
+ auto t = rewriter.create <LoadOp>(loc, foreachOp.getResult (0 ), true );
684
+ rewriter.replaceOpWithNewOp <ConvertOp>(op, dstTp, t);
679
685
return success ();
680
686
}
681
687
};
@@ -694,14 +700,18 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
694
700
int64_t rank = rtp.getRank ();
695
701
auto enc = getSparseTensorEncoding (rtp);
696
702
703
+ SmallVector<Value> reduc = op.getInitArgs ();
704
+
697
705
// 1. Generates loop for the sparse input.
698
706
SparseTensorLoopEmitter loopEmitter (ValueRange{input});
699
707
loopEmitter.initializeLoopEmit (rewriter, loc);
700
708
for (int64_t i = 0 ; i < rank; i++) {
701
709
// TODO: provide utility function for loop sequences that only contains
702
710
// one for loop?
703
711
loopEmitter.enterNewLoopSeq (rewriter, loc, 0 , static_cast <size_t >(i));
704
- loopEmitter.enterLoopOverTensorAtDim (rewriter, loc, 0 , i);
712
+ // Note that reduc will be taken care of by the loop emitter and gets updated
713
+ // in place.
714
+ loopEmitter.enterLoopOverTensorAtDim (rewriter, loc, 0 , i, reduc);
705
715
}
706
716
707
717
SmallVector<Value, 4 > coords;
@@ -716,15 +726,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
716
726
: rewriter.create <memref::LoadOp>(loc, vals, coords);
717
727
718
728
// 2. Inline the block in the foreach operator.
719
- Block::iterator inlinePos = rewriter.getInsertionPoint ();
720
729
Block *srcBlock = op.getBody ();
721
- // Remove sparse_tensor.yield.
722
- rewriter.eraseOp (srcBlock->getTerminator ());
723
-
724
- for (int64_t i = 0 ; i < rank; i++) {
725
- loopEmitter.exitCurrentLoop (rewriter, loc);
726
- loopEmitter.exitCurrentLoopSeq ();
727
- }
728
730
729
731
SmallVector<Value, 4 > args;
730
732
// Remap coordinates.
@@ -734,11 +736,33 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
734
736
}
735
737
// Remap value.
736
738
args.push_back (val);
739
+ // Remap reduction variables.
740
+ args.append (reduc);
741
+
742
+ // Remove sparse_tensor.yield.
743
+ SmallVector<Value> reducValue = srcBlock->getTerminator ()->getOperands ();
744
+ rewriter.eraseOp (srcBlock->getTerminator ());
737
745
738
746
// Inline body.
739
- rewriter.mergeBlockBefore (srcBlock, &*inlinePos, args);
740
- // delete the foreach operator.
741
- rewriter.eraseOp (op);
747
+ if (!reducValue.empty ()) {
748
+ rewriter.mergeBlocks (srcBlock, rewriter.getBlock (), args);
749
+ } else {
750
+ // This is annoying, since scf.for inserts an implicit yield op when
751
+ // there is no reduction variable upon creation, in this case we need to
752
+ // merge the block *before* the yield op.
753
+ rewriter.mergeBlockBefore (srcBlock, &*rewriter.getInsertionPoint (), args);
754
+ }
755
+
756
+ for (int64_t i = 0 ; i < rank; i++) {
757
+ // Link the reduction chain. Note that the loop emitter updates the reducValue
758
+ // in place.
759
+ loopEmitter.exitCurrentLoop (rewriter, loc, reducValue);
760
+ loopEmitter.exitCurrentLoopSeq ();
761
+ }
762
+
763
+ // Replace the foreach operator with the value returned by the outermost
764
+ // for loop.
765
+ rewriter.replaceOp (op, reducValue);
742
766
return success ();
743
767
}
744
768
};
@@ -801,7 +825,8 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
801
825
.getResult (0 );
802
826
Type eltTp = dstTp.getElementType ();
803
827
Value value = genAllocaScalar (rewriter, loc, eltTp);
804
- scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, c0, nnz, c1);
828
+ scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, c0, nnz, c1,
829
+ ArrayRef<Value>(cooBuffer));
805
830
rewriter.setInsertionPointToStart (forOp.getBody ());
806
831
807
832
SmallString<18 > getNextFuncName{" getSparseTensorReaderNext" ,
@@ -816,13 +841,17 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
816
841
loc, indices, constantIndex (rewriter, loc, i)));
817
842
}
818
843
Value v = rewriter.create <memref::LoadOp>(loc, value);
819
- rewriter.create <InsertOp>(loc, v, cooBuffer, indicesArray);
844
+ auto t = rewriter.create <InsertOp>(loc, v, forOp.getRegionIterArg (0 ),
845
+ indicesArray);
846
+ rewriter.create <scf::YieldOp>(loc, ArrayRef<Value>(t));
820
847
rewriter.setInsertionPointAfter (forOp);
848
+ // Link SSA chain.
849
+ cooBuffer = forOp.getResult (0 );
821
850
822
851
// Release the sparse tensor reader.
823
852
createFuncCall (rewriter, loc, " delSparseTensorReader" , {}, {reader},
824
853
EmitCInterface::Off);
825
-
854
+ cooBuffer = rewriter. create <LoadOp>(loc, cooBuffer, true );
826
855
Value newOp = rewriter.replaceOpWithNewOp <ConvertOp>(op, dstTp, cooBuffer);
827
856
828
857
// Release the unordered COO tensor buffer.
0 commit comments