@@ -600,8 +600,8 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
600
600
flattenOperands (adaptor.getOperands (), flattened);
601
601
auto newCall = rewriter.create <func::CallOp>(loc, op.getCallee (),
602
602
finalRetTy, flattened);
603
- // (2) Create cast operation for sparse tensor returns.
604
- SmallVector<Value> castedRet ;
603
+ // (2) Gather sparse tensor returns.
604
+ SmallVector<SmallVector< Value>> packedResultVals ;
605
605
// Tracks the offset of current return value (of the original call)
606
606
// relative to the new call (after sparse tensor flattening);
607
607
unsigned retOffset = 0 ;
@@ -618,21 +618,22 @@ class SparseCallConverter : public OpConversionPattern<func::CallOp> {
618
618
assert (!sparseFlat.empty ());
619
619
if (sparseFlat.size () > 1 ) {
620
620
auto flatSize = sparseFlat.size ();
621
- ValueRange fields (iterator_range<ResultRange::iterator>(
622
- newCall.result_begin () + retOffset,
623
- newCall.result_begin () + retOffset + flatSize));
624
- castedRet.push_back (genTuple (rewriter, loc, retType, fields));
621
+ packedResultVals.push_back (SmallVector<Value>());
622
+ llvm::append_range (packedResultVals.back (),
623
+ newCall.getResults ().slice (retOffset, flatSize));
625
624
retOffset += flatSize;
626
625
} else {
627
626
// If this is an 1:1 conversion, no need for casting.
628
- castedRet.push_back (newCall.getResult (retOffset));
627
+ packedResultVals.emplace_back ();
628
+ packedResultVals.back ().push_back (newCall.getResult (retOffset));
629
629
retOffset++;
630
630
}
631
631
sparseFlat.clear ();
632
632
}
633
633
634
- assert (castedRet.size () == op.getNumResults ());
635
- rewriter.replaceOp (op, castedRet);
634
+ assert (packedResultVals.size () == op.getNumResults ());
635
+ rewriter.replaceOpWithMultiple (
636
+ op, llvm::to_vector_of<ValueRange>(packedResultVals));
636
637
return success ();
637
638
}
638
639
};
@@ -776,7 +777,7 @@ class SparseTensorAllocConverter
776
777
// Reuses specifier.
777
778
fields.push_back (desc.getSpecifier ());
778
779
assert (fields.size () == desc.getNumFields ());
779
- rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields) );
780
+ rewriter.replaceOpWithMultiple (op, { fields} );
780
781
return success ();
781
782
}
782
783
@@ -796,7 +797,7 @@ class SparseTensorAllocConverter
796
797
sizeHint, lvlSizesValues, fields);
797
798
798
799
// Replace operation with resulting memrefs.
799
- rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields) );
800
+ rewriter.replaceOpWithMultiple (op, { fields} );
800
801
return success ();
801
802
}
802
803
@@ -837,7 +838,7 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
837
838
sizeHint, lvlSizesValues, fields);
838
839
839
840
// Replace operation with resulting memrefs.
840
- rewriter.replaceOp (op, genTuple (rewriter, loc, resType, fields) );
841
+ rewriter.replaceOpWithMultiple (op, { fields} );
841
842
return success ();
842
843
}
843
844
@@ -893,7 +894,7 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
893
894
if (op.getHasInserts ())
894
895
genEndInsert (rewriter, op.getLoc (), desc);
895
896
// Replace operation with resulting memrefs.
896
- rewriter.replaceOp (op, genTuple (rewriter, op. getLoc (), desc) );
897
+ rewriter.replaceOpWithMultiple (op, {desc. getFields ()} );
897
898
return success ();
898
899
}
899
900
};
@@ -1006,15 +1007,14 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
1006
1007
rewriter.create <scf::YieldOp>(loc, insertRet);
1007
1008
1008
1009
rewriter.setInsertionPointAfter (loop);
1009
- Value result = genTuple (rewriter, loc, dstType, loop->getResults ());
1010
1010
// Deallocate the buffers on exit of the full loop nest.
1011
1011
Operation *parent = getTop (op);
1012
1012
rewriter.setInsertionPointAfter (parent);
1013
1013
rewriter.create <memref::DeallocOp>(loc, values);
1014
1014
rewriter.create <memref::DeallocOp>(loc, filled);
1015
1015
rewriter.create <memref::DeallocOp>(loc, added);
1016
1016
// Replace operation with resulting memrefs.
1017
- rewriter.replaceOp (op, result );
1017
+ rewriter.replaceOpWithMultiple (op, {loop-> getResults ()} );
1018
1018
return success ();
1019
1019
}
1020
1020
};
@@ -1041,8 +1041,7 @@ class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
1041
1041
params, /* genCall=*/ true );
1042
1042
SmallVector<Value> ret = insertGen.genCallOrInline (rewriter, loc);
1043
1043
// Replace operation with resulting memrefs.
1044
- rewriter.replaceOp (op,
1045
- genTuple (rewriter, loc, op.getDest ().getType (), ret));
1044
+ rewriter.replaceOpWithMultiple (op, {ret});
1046
1045
return success ();
1047
1046
}
1048
1047
};
@@ -1215,8 +1214,7 @@ class SparseConvertConverter : public OpConversionPattern<ConvertOp> {
1215
1214
return true ;
1216
1215
});
1217
1216
1218
- rewriter.replaceOp (
1219
- op, genTuple (rewriter, loc, op.getResult ().getType (), fields));
1217
+ rewriter.replaceOpWithMultiple (op, {fields});
1220
1218
return success ();
1221
1219
}
1222
1220
};
@@ -1271,8 +1269,7 @@ class SparseExtractSliceConverter
1271
1269
// NOTE: we can not generate tuples directly from descriptor here, as the
1272
1270
// descriptor is holding the original type, yet we want the slice type
1273
1271
// here (they shared every memref but with an updated specifier).
1274
- rewriter.replaceOp (op, genTuple (rewriter, loc, op.getResult ().getType (),
1275
- desc.getFields ()));
1272
+ rewriter.replaceOpWithMultiple (op, {desc.getFields ()});
1276
1273
return success ();
1277
1274
}
1278
1275
};
@@ -1403,7 +1400,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
1403
1400
}
1404
1401
desc.setValMemSize (rewriter, loc, memSize);
1405
1402
1406
- rewriter.replaceOp (op, genTuple (rewriter, loc, desc) );
1403
+ rewriter.replaceOpWithMultiple (op, { desc. getFields ()} );
1407
1404
return success ();
1408
1405
}
1409
1406
};
@@ -1577,7 +1574,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
1577
1574
EmitCInterface::Off);
1578
1575
1579
1576
// Replace operation with resulting memrefs.
1580
- rewriter.replaceOp (op, genTuple (rewriter, loc, dstTp, fields) );
1577
+ rewriter.replaceOpWithMultiple (op, { fields} );
1581
1578
return success ();
1582
1579
}
1583
1580
};
0 commit comments