-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][sparse] reuse tensor.insert operation to insert elements into … #84987
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu) Changes…a sparse tensor. Patch is 48.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/84987.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index feed15d6af0544..0498576fcffc51 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -668,48 +668,6 @@ def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
// refined over time as our sparse abstractions evolve.
//===----------------------------------------------------------------------===//
-def SparseTensor_InsertOp : SparseTensor_Op<"insert",
- [TypesMatchWith<"value type matches element type of tensor",
- "tensor", "value",
- "::llvm::cast<ShapedType>($_self).getElementType()">,
- AllTypesMatch<["tensor", "result"]>]>,
- Arguments<(ins AnyType:$value,
- AnySparseTensor:$tensor,
- Variadic<Index>:$lvlCoords)>,
- Results<(outs AnySparseTensor:$result)> {
- string summary = "Inserts a value into the sparse tensor";
- string description = [{
- Inserts the value into the underlying storage of the tensor at the
- given level-coordinates. The arity of `lvlCoords` must match the
- level-rank of the tensor. This operation can only be applied when
- the tensor materializes unintialized from a `tensor.empty` operation
- and the final tensor is constructed with a `load` operation which
- has the `hasInserts` attribute set.
-
- The level-properties of the sparse tensor type fully describe what
- kind of insertion order is allowed. When all levels have "unique"
- and "ordered" properties, for example, insertions should occur in
- strict lexicographical level-coordinate order. Other properties
- define different insertion regimens. Inserting in a way contrary
- to these properties results in undefined behavior.
-
- Note that this operation is "impure" in the sense that even though
- the result is modeled through an SSA value, the insertion is eventually
- done "in place", and referencing the old SSA value is undefined behavior.
- This operation is scheduled to be unified with the dense counterpart
- `tensor.insert` that has pure SSA semantics.
-
- Example:
-
- ```mlir
- %result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
- ```
- }];
- let assemblyFormat = "$value `into` $tensor `[` $lvlCoords `]` attr-dict"
- "`:` type($tensor)";
- let hasVerifier = 1;
-}
-
def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
[TypesMatchWith<"value type matches element type of inBuffer",
"inBuffer", "value",
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index c19907a945d3bb..7750efdd9add0f 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -1741,15 +1741,6 @@ LogicalResult ConcatenateOp::verify() {
return success();
}
-LogicalResult InsertOp::verify() {
- const auto stt = getSparseTensorType(getTensor());
- if (stt.getEncoding().getBatchLvlRank() > 0)
- return emitOpError("batched sparse tensor insertion not implemented");
- if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
- return emitOpError("incorrect number of coordinates");
- return success();
-}
-
void PushBackOp::build(OpBuilder &builder, OperationState &result,
Value curSize, Value inBuffer, Value value) {
build(builder, result, curSize, inBuffer, value, Value());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 3f4ae1f67de150..a942a721e218fe 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -187,27 +187,6 @@ struct DisassembleOpInterface
}
};
-struct InsertOpInterface : public SparseBufferizableOpInterfaceExternalModel<
- InsertOpInterface, sparse_tensor::InsertOp> {
- bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
- const AnalysisState &state) const {
- return true;
- }
-
- bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
- const AnalysisState &state) const {
- // InsertOp writes to memory.
- return true;
- }
-
- AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
- const AnalysisState &state) const {
- // InsertOp returns an alias of its operand.
- assert(op->getNumResults() == 1);
- return {{op->getOpResult(0), BufferRelation::Equivalent}};
- }
-};
-
struct NumberOfEntriesOpInterface
: public SparseBufferizableOpInterfaceExternalModel<
NumberOfEntriesOpInterface, sparse_tensor::NumberOfEntriesOp> {
@@ -324,7 +303,6 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
sparse_tensor::ConvertOp::attachInterface<ConvertOpInterface>(*ctx);
sparse_tensor::LoadOp::attachInterface<LoadOpInterface>(*ctx);
sparse_tensor::NewOp::attachInterface<NewOpInterface>(*ctx);
- sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
sparse_tensor::NumberOfEntriesOp::attachInterface<
NumberOfEntriesOpInterface>(*ctx);
sparse_tensor::AssembleOp::attachInterface<AssembleOpInterface>(*ctx);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
index fbe2fc31ab8b15..f93b59de29e57b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp
@@ -640,14 +640,14 @@ struct TensorInsertDemapper
using DemapInsRewriter::DemapInsRewriter;
LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
PatternRewriter &rewriter) const {
- if (!hasAnySparseResult(op))
+ if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
return failure();
Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
- auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
+ auto insertOp = rewriter.create<tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 44c5d4dbe485bf..56c51c1ab1a84f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -1014,24 +1014,28 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
};
/// Sparse codegen rule for the insert operator.
-class SparseInsertConverter : public OpConversionPattern<InsertOp> {
+class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ if (auto stt = getSparseTensorType(adaptor.getDest()); !stt.hasEncoding()) {
+ assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
+ return failure();
+ }
Location loc = op.getLoc();
- auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
+ auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
- params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
- params.push_back(adaptor.getValue());
- SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
+ params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
+ params.push_back(adaptor.getScalar());
+ SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op,
- genTuple(rewriter, loc, op.getTensor().getType(), ret));
+ genTuple(rewriter, loc, op.getDest().getType(), ret));
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 010c3aa58b72c7..0937c10f257283 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -580,17 +580,24 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
};
/// Sparse conversion rule for the insertion operator.
-class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
+class SparseTensorInsertConverter
+ : public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
- matchAndRewrite(InsertOp op, OpAdaptor adaptor,
+ matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Note that the current regime only allows for strict lexicographic
// coordinate order. All values are passed by reference through stack
// allocated memrefs.
Location loc = op->getLoc();
- const auto stt = getSparseTensorType(op.getTensor());
+ const auto stt = getSparseTensorType(op.getDest());
+
+ // Dense tensor insertion.
+ if (!stt.hasEncoding())
+ return failure();
+
+ assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
const auto elemTp = stt.getElementType();
const Level lvlRank = stt.getLvlRank();
Value lvlCoords, vref;
@@ -608,12 +615,12 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
vref = genAllocaScalar(rewriter, loc, elemTp);
}
- storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
- rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
+ storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
+ rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
createFuncCall(rewriter, loc, name, {},
- {adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
- rewriter.replaceOp(op, adaptor.getTensor());
+ {adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
+ rewriter.replaceOp(op, adaptor.getDest());
return success();
}
};
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a65bce78d095cf..17f70d0796ccfc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -817,7 +817,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
dstSizes, dstDcvs);
- auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+ auto t =
+ builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
@@ -901,7 +902,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
SmallVector<Value> dstDcvs;
reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
srcDcvs, dstSizes, dstDcvs);
- auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+ auto t =
+ builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 1fb70ed5035c03..cd046b670d9a8e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -428,7 +428,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
/*else=*/true);
// True branch.
builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
- Value res = builder.create<InsertOp>(loc, rhs, chain, ivs);
+ Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
builder.create<scf::YieldOp>(loc, res);
// False branch.
builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
@@ -438,7 +438,8 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
env.updateInsertionChain(ifValidLexInsert.getResult(0));
} else {
// Generates regular insertion chain.
- env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
+ env.updateInsertionChain(
+ builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
}
return;
}
diff --git a/mlir/test/Dialect/SparseTensor/codegen.mlir b/mlir/test/Dialect/SparseTensor/codegen.mlir
index b63762485c961f..40bfa1e4e2a501 100644
--- a/mlir/test/Dialect/SparseTensor/codegen.mlir
+++ b/mlir/test/Dialect/SparseTensor/codegen.mlir
@@ -643,7 +643,7 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
- %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
+ %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
return %1 : tensor<128xf64, #SV>
}
@@ -666,7 +666,7 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
- %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+ %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
return %1 : tensor<128xf64, #SparseVector>
}
@@ -690,7 +690,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
- %0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
+ %0 = tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
%1 = sparse_tensor.load %0 hasInserts : tensor<5x6xf64, #Coo>
return %1 : tensor<5x6xf64, #Coo>
}
diff --git a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
index eaef6a31585292..f9559ce648c783 100644
--- a/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
+++ b/mlir/test/Dialect/SparseTensor/constant_index_map.mlir
@@ -20,7 +20,7 @@
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1
-// CHECK: %[[VAL_14:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
+// CHECK: %[[VAL_14:.*]] = tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
// CHECK: scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}>
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 465f2108626606..f23f6ac4f181e2 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -318,7 +318,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
%arg1: index,
%arg2: f32) -> tensor<128xf32, #SparseVector> {
- %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
+ %0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
return %0 : tensor<128xf32, #SparseVector>
}
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
index eac97f702f58bd..48f28ef390ed53 100644
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -290,24 +290,6 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
// -----
-func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: index, %arg2: f64) {
- // expected-error@+1 {{'sparse_tensor.insert' 'tensor' must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
- sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64>
- return
-}
-
-// -----
-
-#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
-
-func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: index, %arg2: f64) {
- // expected-error@+1 {{'sparse_tensor.insert' op incorrect number of coordinates}}
- sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128x64xf64, #CSR>
- return
-}
-
-// -----
-
func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f32) -> (memref<?xf64>, index) {
// expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
%0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f32
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
index 41094fbad9218f..e9e458e805ba47 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -311,10 +311,10 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[B:.*]]: index,
// CHECK-SAME: %[[C:.*]]: f64)
-// CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
+// CHECK: %[[T:.*]] = tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
- %0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
+ ...
[truncated]
|
Note that this operation is "impure" in the sense that even though | ||
the result is modeled through an SSA value, the insertion is eventually | ||
done "in place", and referencing the old SSA value is undefined behavior. | ||
This operation is scheduled to be unified with the dense counterpart |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think these were the two reasons why we initially had a completely separate op were (1) order of insertions and (2) use of old ssa value, but not honoring such semantics seems more a lowering problem than something that should be called out in the op anymore.
…a sparse tensor.