Commit 70508b6

Author: Peiming Liu

[mlir][sparse] fix sparse tensor rewriting patterns that do not propagate sparse tensor SSA properly.

Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D137468

1 parent: 7ded25c

File tree: 8 files changed, +181 −120 lines

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 4 additions & 1 deletion
@@ -603,9 +603,12 @@ void ForeachOp::build(
   std::fill_n(std::back_inserter(blockArgTypes), rank, builder.getIndexType());
   // Followed by one value.
   blockArgTypes.push_back(rtp.getElementType());
+  // Followed by reduction variable.
+  blockArgTypes.append(initArgs.getTypes().begin(), initArgs.getTypes().end());
 
   SmallVector<Location, 4> blockArgLocs;
-  std::fill_n(std::back_inserter(blockArgLocs), rank + 1, tensor.getLoc());
+  std::fill_n(std::back_inserter(blockArgLocs), blockArgTypes.size(),
+              tensor.getLoc());
 
   OpBuilder::InsertionGuard guard(builder);
   auto &region = *result.regions.front();
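
For context, the updated builder lays out the foreach body's block arguments as: rank index arguments, one element value, then one argument per init value, all sharing the tensor's location. A caller-side sketch with hypothetical `tensor` and `initVal` values (the callback shape matches the rewriting patterns below):

    // Sketch only: the body receives the coordinates, the element value,
    // and the loop-carried reduction values corresponding to `initArgs`.
    rewriter.create<ForeachOp>(
        loc, tensor, /*initArgs=*/ValueRange{initVal},
        [&](OpBuilder &builder, Location loc, ValueRange coords, Value v,
            ValueRange reduc) {
          // reduc.front() is the SSA value threaded through the loop nest;
          // yield it (possibly updated) to keep the chain intact.
          builder.create<sparse_tensor::YieldOp>(loc, reduc.front());
        });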

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 3 additions & 0 deletions
@@ -956,6 +956,9 @@ Value mlir::sparse_tensor::genValueForDense(OpBuilder &builder, Location loc,
   return val;
 }
 
+// FIXME:
+// 1. Dense tensors loop should be generated by loop emitter.
+// 2. Support reduction variables to propagate SSA chains properly.
 void mlir::sparse_tensor::genDenseTensorOrSparseConstantIterLoop(
     OpBuilder &builder, Location loc, Value src, unsigned rank,
     function_ref<void(OpBuilder &, Location, Value, ValueRange)> bodyBuilder) {
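
The second FIXME item ties into the "SSA chain is broken" markers in the tests below: this helper's bodyBuilder carries no reduction operands, so a pattern iterating a dense input can only capture the COO buffer from the enclosing scope, and the tensor produced by each insertion is dropped. A sketch of the problematic shape (illustrative names, following the signature above):

    // Problematic: no ValueRange for reductions in the callback, so the
    // result of InsertOp cannot be yielded back through the loop nest.
    genDenseTensorOrSparseConstantIterLoop(
        rewriter, loc, src, rank,
        [&](OpBuilder &builder, Location loc, Value v, ValueRange indices) {
          builder.create<InsertOp>(loc, v, cooBuffer, indices); // result lost
        });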

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 74 additions & 45 deletions
@@ -356,8 +356,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
     RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
     auto cooBuffer =
         rewriter.create<AllocTensorOp>(loc, cooTp, dstDynSizes).getResult();
-    rewriter.create<ForeachOp>(
-        loc, srcTensor, llvm::None,
+    ForeachOp foreachOp = rewriter.create<ForeachOp>(
+        loc, srcTensor, cooBuffer,
         [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
             ValueRange reduc) {
           SmallVector<Value, 4> srcIndices;

@@ -368,11 +368,11 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
           }
           translateIndicesArray(builder, loc, op.getReassociationIndices(),
                                 srcIndices, srcSizes, dstSizes, dstIndices);
-          builder.create<InsertOp>(loc, v, cooBuffer, dstIndices);
-          builder.create<sparse_tensor::YieldOp>(loc);
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstIndices);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
         });
-
-    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
+    auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, t);
     return success();
   }
 };
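
This is the recurring fix throughout this commit: insert into the loop-carried tensor from `reduc` instead of the captured buffer, yield the insertion's result, and finalize the value carried out of the foreach with a load. A condensed sketch of the idiom (illustrative names):

    // The buffer is threaded as a reduction value instead of being
    // captured; every InsertOp result feeds the next iteration.
    ForeachOp foreachOp = rewriter.create<ForeachOp>(
        loc, srcTensor, /*initArgs=*/cooBuffer,
        [&](OpBuilder &builder, Location loc, ValueRange coords, Value v,
            ValueRange reduc) {
          Value t = builder.create<InsertOp>(loc, v, reduc.front(), coords);
          builder.create<sparse_tensor::YieldOp>(loc, t);
        });
    // Finalize insertions on the tensor carried out of the loop nest.
    Value result = rewriter.create<LoadOp>(loc, foreachOp.getResult(0),
                                           /*hasInserts=*/true);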
@@ -442,13 +442,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
         rewriter.create<AllocTensorOp>(loc, cooTp, ValueRange()).getResult();
 
     Value offset = constantIndex(rewriter, loc, 0);
+    ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
       // Builds the indexing map.
 
       // Build a for op for each input tensor to append new values into the
       // output tensor.
-      rewriter.create<ForeachOp>(
-          loc, input, llvm::None,
+      foreachOp = rewriter.create<ForeachOp>(
+          loc, input, cooBuffer,
           [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
               ValueRange reduc) {
             SmallVector<Value, 4> indices;

@@ -461,8 +462,8 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
               idx = builder.create<arith::AddIOp>(loc, idx, offset);
               indices.push_back(idx);
             }
-            builder.create<InsertOp>(loc, v, cooBuffer, indices);
-            builder.create<sparse_tensor::YieldOp>(loc);
+            auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+            builder.create<sparse_tensor::YieldOp>(loc, t);
           });
       // Accumulates the offset. Note that only static-shaped inputs are allowed
       // by concatenate op verifier, which saves us from computing the offset

@@ -471,7 +472,10 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
       assert(!ShapedType::isDynamic(d));
       offset = rewriter.create<arith::AddIOp>(loc, offset,
                                               constantIndex(rewriter, loc, d));
+      cooBuffer = foreachOp.getResult(0);
     }
+
+    cooBuffer = rewriter.create<LoadOp>(loc, cooBuffer, true);
     rewriter.replaceOpWithNewOp<ConvertOp>(op, rtp, cooBuffer);
     return success();
   }
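
With several inputs writing into one buffer, each loop's result must also feed the next loop's init value; otherwise every loop after the first would insert into a stale tensor. A minimal sketch of that chaining (illustrative names):

    Value buffer = cooBuffer;
    for (Value input : inputs) {
      ForeachOp foreachOp = rewriter.create<ForeachOp>(
          loc, input, buffer,
          [&](OpBuilder &builder, Location loc, ValueRange coords, Value v,
              ValueRange reduc) {
            Value t = builder.create<InsertOp>(loc, v, reduc.front(), coords);
            builder.create<sparse_tensor::YieldOp>(loc, t);
          });
      buffer = foreachOp.getResult(0); // carry the SSA chain to the next loop
    }
    Value result = rewriter.create<LoadOp>(loc, buffer, /*hasInserts=*/true);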
@@ -602,19 +606,19 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
       srcTp = getUnorderedCOOFromType(srcTp);
       tmpCoo =
          rewriter.create<AllocTensorOp>(loc, srcTp, dynSrcSizes).getResult();
-      rewriter.create<ForeachOp>(
-          loc, src, llvm::None,
+      auto foreachOp = rewriter.create<ForeachOp>(
+          loc, src, tmpCoo,
           [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
               ValueRange reduc) {
             SmallVector<Value, 4> indices;
             for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
               uint64_t dim = toStoredDim(encSrc, i);
               indices.push_back(args[dim]);
             }
-            builder.create<InsertOp>(loc, v, tmpCoo, indices);
-            builder.create<sparse_tensor::YieldOp>(loc);
+            auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+            builder.create<sparse_tensor::YieldOp>(loc, t);
           });
-      src = tmpCoo;
+      src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
     }
 
     // Sort the COO tensor so that its elements are ordered via increasing

@@ -653,29 +657,31 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
     getDynamicSizes(dstTp, srcSizes, dynDstSizes);
     Value dst =
         rewriter.create<AllocTensorOp>(loc, dstTp, dynDstSizes).getResult();
-    rewriter.create<ForeachOp>(loc, src, llvm::None,
-                               [&](OpBuilder &builder, Location loc,
-                                   ValueRange args, Value v, ValueRange reduc) {
-                                 SmallVector<Value, 4> indices;
-                                 for (int64_t i = 0, e = srcTp.getRank(); i < e;
-                                      i++) {
-                                   uint64_t dim = toStoredDim(encDst, i);
-                                   indices.push_back(args[dim]);
-                                 }
-                                 builder.create<InsertOp>(loc, v, dst, indices);
-                                 builder.create<sparse_tensor::YieldOp>(loc);
-                               });
+    auto foreachOp = rewriter.create<ForeachOp>(
+        loc, src, dst,
+        [&](OpBuilder &builder, Location loc, ValueRange args, Value v,
+            ValueRange reduc) {
+          SmallVector<Value, 4> indices;
+          for (int64_t i = 0, e = srcTp.getRank(); i < e; i++) {
+            uint64_t dim = toStoredDim(encDst, i);
+            indices.push_back(args[dim]);
+          }
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), indices);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
+        });
 
-    // Release the temporary COO if it is created.
+    // Release the temporary COO if it is created. Note that tmpCoo is
+    // invalidated due to foreach and updated to src.
     if (tmpCoo)
-      rewriter.create<DeallocTensorOp>(loc, tmpCoo);
+      rewriter.create<DeallocTensorOp>(loc, src);
 
     // Directly replace op with dst results in bufferization error message
     // "sparse tensor allocation should not escape function".
     // As such, we insert a trivial tensor convert which will be removed by
     // codegen.
     rewriter.setInsertionPointAfter(op);
-    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, dst);
+    auto t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, t);
     return success();
   }
 };
@@ -694,14 +700,18 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     int64_t rank = rtp.getRank();
     auto enc = getSparseTensorEncoding(rtp);
 
+    SmallVector<Value> reduc = op.getInitArgs();
+
     // 1. Generates loop for the sparse input.
     SparseTensorLoopEmitter loopEmitter(ValueRange{input});
     loopEmitter.initializeLoopEmit(rewriter, loc);
     for (int64_t i = 0; i < rank; i++) {
       // TODO: provide utility function for loop sequences that only contains
       // one for loop?
       loopEmitter.enterNewLoopSeq(rewriter, loc, 0, static_cast<size_t>(i));
-      loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i);
+      // Note that reduc will be taken care of by loop emitter and get updated
+      // in place.
+      loopEmitter.enterLoopOverTensorAtDim(rewriter, loc, 0, i, reduc);
     }
 
     SmallVector<Value, 4> coords;

@@ -716,15 +726,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
                   : rewriter.create<memref::LoadOp>(loc, vals, coords);
 
     // 2. Inline the block in the foreach operator.
-    Block::iterator inlinePos = rewriter.getInsertionPoint();
     Block *srcBlock = op.getBody();
-    // Remove sparse_tensor.yield.
-    rewriter.eraseOp(srcBlock->getTerminator());
-
-    for (int64_t i = 0; i < rank; i++) {
-      loopEmitter.exitCurrentLoop(rewriter, loc);
-      loopEmitter.exitCurrentLoopSeq();
-    }
 
     SmallVector<Value, 4> args;
     // Remap coordinates.

@@ -734,11 +736,33 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     }
     // Remap value.
     args.push_back(val);
+    // Remap reduction variables.
+    args.append(reduc);
+
+    // Remove sparse_tensor.yield.
+    SmallVector<Value> reducValue = srcBlock->getTerminator()->getOperands();
+    rewriter.eraseOp(srcBlock->getTerminator());
 
     // Inline body.
-    rewriter.mergeBlockBefore(srcBlock, &*inlinePos, args);
-    // delete the foreach operator.
-    rewriter.eraseOp(op);
+    if (!reducValue.empty()) {
+      rewriter.mergeBlocks(srcBlock, rewriter.getBlock(), args);
+    } else {
+      // This is annoying, since scf.for inserts a implicit yield op when
+      // there is no reduction variable upon creation, in this case we need to
+      // merge the block *before* the yield op.
+      rewriter.mergeBlockBefore(srcBlock, &*rewriter.getInsertionPoint(), args);
+    }
+
+    for (int64_t i = 0; i < rank; i++) {
+      // Link the reduction chain. Note that loop emitter update the reducValue
+      // in place.
+      loopEmitter.exitCurrentLoop(rewriter, loc, reducValue);
+      loopEmitter.exitCurrentLoopSeq();
+    }
+
+    // Replace the foreach operator with the value returned by the outtermost
+    // for loop.
+    rewriter.replaceOp(op, reducValue);
     return success();
   }
 };
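
The asymmetry between `mergeBlocks` and `mergeBlockBefore` above follows from the scf.for builder behavior the inline comment describes: without iter_args it materializes an implicit scf.yield, with iter_args it leaves the terminator to the caller. A sketch of the two shapes (assuming the standard scf::ForOp builder; names illustrative):

    // With iter_args: the fresh body has no terminator yet, so a source
    // block can be appended wholesale and its yield installed afterwards.
    auto loopWithArgs =
        builder.create<scf::ForOp>(loc, lo, hi, step, ValueRange{init});

    // Without iter_args: the builder already inserted an empty scf.yield,
    // so inlined code must be spliced in *before* that terminator.
    auto loopNoArgs = builder.create<scf::ForOp>(loc, lo, hi, step);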
@@ -801,7 +825,8 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
             .getResult(0);
     Type eltTp = dstTp.getElementType();
     Value value = genAllocaScalar(rewriter, loc, eltTp);
-    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, c0, nnz, c1,
+                                                   ArrayRef<Value>(cooBuffer));
     rewriter.setInsertionPointToStart(forOp.getBody());
 
     SmallString<18> getNextFuncName{"getSparseTensorReaderNext",

@@ -816,13 +841,17 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
           loc, indices, constantIndex(rewriter, loc, i)));
     }
     Value v = rewriter.create<memref::LoadOp>(loc, value);
-    rewriter.create<InsertOp>(loc, v, cooBuffer, indicesArray);
+    auto t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
+                                       indicesArray);
+    rewriter.create<scf::YieldOp>(loc, ArrayRef<Value>(t));
     rewriter.setInsertionPointAfter(forOp);
+    // Link SSA chain.
+    cooBuffer = forOp.getResult(0);
 
     // Release the sparse tensor reader.
     createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
                    EmitCInterface::Off);
-
+    cooBuffer = rewriter.create<LoadOp>(loc, cooBuffer, true);
     Value newOp = rewriter.replaceOpWithNewOp<ConvertOp>(op, dstTp, cooBuffer);
 
     // Release the unordered COO tensor buffer.
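
The same chaining applies to this hand-built scf.for: insert into the region iter arg, yield the update, then continue with the loop's result. Condensed from the diff above (illustrative names):

    scf::ForOp forOp =
        rewriter.create<scf::ForOp>(loc, c0, nnz, c1, ValueRange{cooBuffer});
    rewriter.setInsertionPointToStart(forOp.getBody());
    // Insert into the loop-carried tensor, not the captured SSA value.
    Value t = rewriter.create<InsertOp>(loc, v, forOp.getRegionIterArg(0),
                                        indicesArray);
    rewriter.create<scf::YieldOp>(loc, ValueRange{t});
    rewriter.setInsertionPointAfter(forOp);
    cooBuffer = forOp.getResult(0); // the updated tensor after the loop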

mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir

Lines changed: 14 additions & 8 deletions
@@ -116,6 +116,7 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 // CHECK-RWT: %[[V:.*]] = tensor.extract %[[A]]{{\[}}%[[FI]], %[[FJ]]] : tensor<2x4xf64>
 // CHECK-RWT: %[[NZ:.*]] = arith.cmpf une, %[[V]], %[[F0]] : f64
 // CHECK-RWT: scf.if %[[NZ]] {
+// // FIXME: the SSA chain is broken here!
 // CHECK-RWT: %{{.*}} = sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[FI]], %[[FJ]]]
 // CHECK-RWT: }
 // CHECK-RWT: }

@@ -126,11 +127,13 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
 // CHECK-RWT: %[[V2:.*]] = sparse_tensor.values %[[COO]]
 // CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]], %[[I1]] jointly %[[V2]]
 // CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: sparse_tensor.foreach in %[[COO]]
-// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64):
-// CHECK-RWT: sparse_tensor.insert %[[FV]] into %[[DST]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: %[[NEW_T:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[DST]])
+// CHECK-RWT: ^bb0(%[[FI0:.*]]: index, %[[FI1:.*]]: index, %[[FV:.*]]: f64, %[[R0:.*]]: tensor
+// CHECK-RWT: %[[RET:.*]] = sparse_tensor.insert %[[FV]] into %[[R0]]{{\[}}%[[FI0]], %[[FI1]]]
+// CHECK-RWT: sparse_tensor.yield %[[RET]]
 // CHECK-RWT: }
-// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: %[[NT:.*]] = sparse_tensor.load %[[NEW_T]] hasInserts
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[NT]]
 // CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
 // CHECK-RWT: return %[[R]] : tensor<2x4xf64, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {

@@ -179,6 +182,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 // CHECK-RWT: %[[I1r:.*]] = tensor.extract %[[SI]]{{\[}}%[[FI]], %[[C1]]] : tensor<2x2xi64>
 // CHECK-RWT: %[[I1:.*]] = arith.index_cast %[[I1r]] : i64 to index
 // CHECK-RWT: %[[V:.*]] = tensor.extract %[[SV]]{{\[}}%[[FI]]] : tensor<2xf32>
+// // FIXME: the SSA chain is broken here!
 // CHECK-RWT: sparse_tensor.insert %[[V]] into %[[COO]]{{\[}}%[[I0]], %[[I1]]]
 // CHECK-RWT: }
 // CHECK-RWT: %[[TI0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}

@@ -187,11 +191,13 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
 // CHECK-RWT: %[[TV:.*]] = sparse_tensor.values %[[COO]]
 // CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[TI0]], %[[TI1]] jointly %[[TV]]
 // CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: sparse_tensor.foreach in %[[COO]]
-// CHECK-RWT: ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32):
-// CHECK-RWT: sparse_tensor.insert %[[F2V]] into %[[DST]]{{\[}}%[[F2I0]], %[[F2I1]]]
+// CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[COO]] init(%[[DST]])
+// CHECK-RWT: ^bb0(%[[F2I0:.*]]: index, %[[F2I1:.*]]: index, %[[F2V:.*]]: f32, %[[R0:.*]]: tensor
+// CHECK-RWT: %[[NEW_T:.*]] = sparse_tensor.insert %[[F2V]] into %[[R0]]{{\[}}%[[F2I0]], %[[F2I1]]]
+// CHECK-RWT: sparse_tensor.yield %[[NEW_T]]
 // CHECK-RWT: }
-// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T]]
 // CHECK-RWT: bufferization.dealloc_tensor %[[COO]]
 // CHECK-RWT: return %[[R]] : tensor<8x7xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed" ] }>>
 func.func @sparse_constant() -> tensor<8x7xf32, #CSR>{

mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir

Lines changed: 6 additions & 4 deletions
@@ -94,11 +94,13 @@ func.func @sparse_convert_1d_ss(%arg0: tensor<?xf32, #SparseVector64>) -> tensor
 // CHECK-RWT: %[[V:.*]] = sparse_tensor.values %[[A]]
 // CHECK-RWT: sparse_tensor.sort %[[NNZ]], %[[I0]] jointly %[[V]]
 // CHECK-RWT: %[[DST:.*]] = bufferization.alloc_tensor(%[[D]])
-// CHECK-RWT: sparse_tensor.foreach in %[[A]]
-// CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32):
-// CHECK-RWT: sparse_tensor.insert %[[FV2]] into %[[DST]]{{\[}}%[[FI2]]]
+// CHECK-RWT: %[[RET:.*]] = sparse_tensor.foreach in %[[A]] init(%[[DST]])
+// CHECK-RWT: ^bb0(%[[FI2:.*]]: index, %[[FV2:.*]]: f32, %[[T:.*]]: tensor<?xf32,
+// CHECK-RWT: %[[I:.*]] = sparse_tensor.insert %[[FV2]] into %[[T]]{{\[}}%[[FI2]]]
+// CHECK-RWT: sparse_tensor.yield %[[I]]
 // CHECK-RWT: }
-// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[DST]]
+// CHECK-RWT: %[[T:.*]] = sparse_tensor.load %[[RET]] hasInserts
+// CHECK-RWT: %[[R:.*]] = sparse_tensor.convert %[[T]]
 // CHECK-RWT: return %[[R]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ], pointerBitWidth = 32, indexBitWidth = 32 }>>
 func.func @sparse_convert(%arg0: tensor<?xf32, #SparseVector64>) -> tensor<?xf32, #SparseVector32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xf32, #SparseVector64> to tensor<?xf32, #SparseVector32>

mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir

Lines changed: 6 additions & 5 deletions
@@ -18,18 +18,19 @@
 // CHECK: %[[T:.*]] = bufferization.alloc_tensor(%[[D0]], %[[D1]])
 // CHECK: %[[N:.*]] = call @getSparseTensorReaderNNZ(%[[R]])
 // CHECK: %[[VB:.*]] = memref.alloca()
-// CHECK: scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] {
+// CHECK: %[[T2:.*]] = scf.for %{{.*}} = %[[C0]] to %[[N]] step %[[C1]] iter_args(%[[A2:.*]] = %[[T]])
 // CHECK: func.call @getSparseTensorReaderNextF32(%[[R]], %[[DS]], %[[VB]])
 // CHECK: %[[E0:.*]] = memref.load %[[DS]]{{\[}}%[[C0]]]
 // CHECK: %[[E1:.*]] = memref.load %[[DS]]{{\[}}%[[C1]]]
 // CHECK: %[[V:.*]] = memref.load %[[VB]][]
-// CHECK: sparse_tensor.insert %[[V]] into %[[T]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK: %[[T1:.*]] = sparse_tensor.insert %[[V]] into %[[A2]]{{\[}}%[[E0]], %[[E1]]]
+// CHECK: scf.yield %[[T1]]
 // CHECK: }
 // CHECK: call @delSparseTensorReader(%[[R]])
-// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T]]
-// CHECK: bufferization.dealloc_tensor %[[T]]
+// CHECK: %[[T3:.*]] = sparse_tensor.load %[[T2]] hasInserts
+// CHECK: %[[R:.*]] = sparse_tensor.convert %[[T3]]
+// CHECK: bufferization.dealloc_tensor %[[T3]]
 // CHECK: return %[[R]]
-// CHECK: }
 func.func @sparse_new(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
   %0 = sparse_tensor.new %arg0 : !llvm.ptr<i8> to tensor<?x?xf32, #CSR>
   return %0 : tensor<?x?xf32, #CSR>
