Skip to content

Commit a5e8a22

Browse files
author
Peiming Liu
committed
[mlir][sparse] reuse tensor.insert operation to insert elements into a sparse tensor.
1 parent 9ac0315 commit a5e8a22

26 files changed

+106
-182
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -668,48 +668,6 @@ def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
668668
// refined over time as our sparse abstractions evolve.
669669
//===----------------------------------------------------------------------===//
670670

671-
def SparseTensor_InsertOp : SparseTensor_Op<"insert",
672-
[TypesMatchWith<"value type matches element type of tensor",
673-
"tensor", "value",
674-
"::llvm::cast<ShapedType>($_self).getElementType()">,
675-
AllTypesMatch<["tensor", "result"]>]>,
676-
Arguments<(ins AnyType:$value,
677-
AnySparseTensor:$tensor,
678-
Variadic<Index>:$lvlCoords)>,
679-
Results<(outs AnySparseTensor:$result)> {
680-
string summary = "Inserts a value into the sparse tensor";
681-
string description = [{
682-
Inserts the value into the underlying storage of the tensor at the
683-
given level-coordinates. The arity of `lvlCoords` must match the
684-
level-rank of the tensor. This operation can only be applied when
685-
the tensor materializes uninitialized from a `tensor.empty` operation
686-
and the final tensor is constructed with a `load` operation which
687-
has the `hasInserts` attribute set.
688-
689-
The level-properties of the sparse tensor type fully describe what
690-
kind of insertion order is allowed. When all levels have "unique"
691-
and "ordered" properties, for example, insertions should occur in
692-
strict lexicographical level-coordinate order. Other properties
693-
define different insertion regimens. Inserting in a way contrary
694-
to these properties results in undefined behavior.
695-
696-
Note that this operation is "impure" in the sense that even though
697-
the result is modeled through an SSA value, the insertion is eventually
698-
done "in place", and referencing the old SSA value is undefined behavior.
699-
This operation is scheduled to be unified with the dense counterpart
700-
`tensor.insert` that has pure SSA semantics.
701-
702-
Example:
703-
704-
```mlir
705-
%result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
706-
```
707-
}];
708-
let assemblyFormat = "$value `into` $tensor `[` $lvlCoords `]` attr-dict"
709-
"`:` type($tensor)";
710-
let hasVerifier = 1;
711-
}
712-
713671
def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
714672
[TypesMatchWith<"value type matches element type of inBuffer",
715673
"inBuffer", "value",

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,15 +1741,6 @@ LogicalResult ConcatenateOp::verify() {
17411741
return success();
17421742
}
17431743

1744-
LogicalResult InsertOp::verify() {
1745-
const auto stt = getSparseTensorType(getTensor());
1746-
if (stt.getEncoding().getBatchLvlRank() > 0)
1747-
return emitOpError("batched sparse tensor insertion not implemented");
1748-
if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
1749-
return emitOpError("incorrect number of coordinates");
1750-
return success();
1751-
}
1752-
17531744
void PushBackOp::build(OpBuilder &builder, OperationState &result,
17541745
Value curSize, Value inBuffer, Value value) {
17551746
build(builder, result, curSize, inBuffer, value, Value());

mlir/lib/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -187,27 +187,6 @@ struct DisassembleOpInterface
187187
}
188188
};
189189

190-
struct InsertOpInterface : public SparseBufferizableOpInterfaceExternalModel<
191-
InsertOpInterface, sparse_tensor::InsertOp> {
192-
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
193-
const AnalysisState &state) const {
194-
return true;
195-
}
196-
197-
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
198-
const AnalysisState &state) const {
199-
// InsertOp writes to memory.
200-
return true;
201-
}
202-
203-
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
204-
const AnalysisState &state) const {
205-
// InsertOp returns an alias of its operand.
206-
assert(op->getNumResults() == 1);
207-
return {{op->getOpResult(0), BufferRelation::Equivalent}};
208-
}
209-
};
210-
211190
struct NumberOfEntriesOpInterface
212191
: public SparseBufferizableOpInterfaceExternalModel<
213192
NumberOfEntriesOpInterface, sparse_tensor::NumberOfEntriesOp> {
@@ -324,7 +303,6 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
324303
sparse_tensor::ConvertOp::attachInterface<ConvertOpInterface>(*ctx);
325304
sparse_tensor::LoadOp::attachInterface<LoadOpInterface>(*ctx);
326305
sparse_tensor::NewOp::attachInterface<NewOpInterface>(*ctx);
327-
sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
328306
sparse_tensor::NumberOfEntriesOp::attachInterface<
329307
NumberOfEntriesOpInterface>(*ctx);
330308
sparse_tensor::AssembleOp::attachInterface<AssembleOpInterface>(*ctx);

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -640,14 +640,14 @@ struct TensorInsertDemapper
640640
using DemapInsRewriter::DemapInsRewriter;
641641
LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
642642
PatternRewriter &rewriter) const {
643-
if (!hasAnySparseResult(op))
643+
if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
644644
return failure();
645645

646646
Location loc = op.getLoc();
647647
auto stt = getSparseTensorType(op.getResult());
648648
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
649649
CrdTransDirectionKind::dim2lvl);
650-
auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
650+
auto insertOp = rewriter.create<tensor::InsertOp>(
651651
loc, op.getScalar(), adaptor.getDest(), lvlCrd);
652652

653653
Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,24 +1014,29 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
10141014
};
10151015

10161016
/// Sparse codegen rule for the insert operator.
1017-
class SparseInsertConverter : public OpConversionPattern<InsertOp> {
1017+
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
10181018
public:
10191019
using OpConversionPattern::OpConversionPattern;
10201020
LogicalResult
1021-
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
1021+
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
10221022
ConversionPatternRewriter &rewriter) const override {
1023+
auto stt = getSparseTensorType(adaptor.getDest());
1024+
if (!stt.hasEncoding())
1025+
return failure();
1026+
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
1027+
10231028
Location loc = op.getLoc();
1024-
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
1029+
auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
10251030
TypeRange flatSpTensorTps = desc.getFields().getTypes();
10261031
SmallVector<Value> params = llvm::to_vector(desc.getFields());
1027-
params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
1028-
params.push_back(adaptor.getValue());
1029-
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
1032+
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
1033+
params.push_back(adaptor.getScalar());
1034+
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
10301035
params, /*genCall=*/true);
10311036
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
10321037
// Replace operation with resulting memrefs.
10331038
rewriter.replaceOp(op,
1034-
genTuple(rewriter, loc, op.getTensor().getType(), ret));
1039+
genTuple(rewriter, loc, op.getDest().getType(), ret));
10351040
return success();
10361041
}
10371042
};

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -580,17 +580,24 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
580580
};
581581

582582
/// Sparse conversion rule for the insertion operator.
583-
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
583+
class SparseTensorInsertConverter
584+
: public OpConversionPattern<tensor::InsertOp> {
584585
public:
585586
using OpConversionPattern::OpConversionPattern;
586587
LogicalResult
587-
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
588+
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
588589
ConversionPatternRewriter &rewriter) const override {
589590
// Note that the current regime only allows for strict lexicographic
590591
// coordinate order. All values are passed by reference through stack
591592
// allocated memrefs.
592593
Location loc = op->getLoc();
593-
const auto stt = getSparseTensorType(op.getTensor());
594+
const auto stt = getSparseTensorType(op.getDest());
595+
596+
// Dense tensor insertion.
597+
if (!stt.hasEncoding())
598+
return failure();
599+
600+
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
594601
const auto elemTp = stt.getElementType();
595602
const Level lvlRank = stt.getLvlRank();
596603
Value lvlCoords, vref;
@@ -608,12 +615,12 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
608615
lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
609616
vref = genAllocaScalar(rewriter, loc, elemTp);
610617
}
611-
storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
612-
rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
618+
storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
619+
rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
613620
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
614621
createFuncCall(rewriter, loc, name, {},
615-
{adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
616-
rewriter.replaceOp(op, adaptor.getTensor());
622+
{adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
623+
rewriter.replaceOp(op, adaptor.getDest());
617624
return success();
618625
}
619626
};

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
817817
reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
818818
dstSizes, dstDcvs);
819819

820-
auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
820+
auto t =
821+
builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
821822
builder.create<sparse_tensor::YieldOp>(loc, t);
822823
});
823824

@@ -901,7 +902,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
901902
SmallVector<Value> dstDcvs;
902903
reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
903904
srcDcvs, dstSizes, dstDcvs);
904-
auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
905+
auto t =
906+
builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
905907
builder.create<sparse_tensor::YieldOp>(loc, t);
906908
});
907909

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
428428
/*else=*/true);
429429
// True branch.
430430
builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
431-
Value res = builder.create<InsertOp>(loc, rhs, chain, ivs);
431+
Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
432432
builder.create<scf::YieldOp>(loc, res);
433433
// False branch.
434434
builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
@@ -438,7 +438,8 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
438438
env.updateInsertionChain(ifValidLexInsert.getResult(0));
439439
} else {
440440
// Generates regular insertion chain.
441-
env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
441+
env.updateInsertionChain(
442+
builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
442443
}
443444
return;
444445
}

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
643643
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
644644
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
645645
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
646-
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
646+
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
647647
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
648648
return %1 : tensor<128xf64, #SV>
649649
}
@@ -666,7 +666,7 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
666666
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
667667
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
668668
func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
669-
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
669+
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
670670
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
671671
return %1 : tensor<128xf64, #SparseVector>
672672
}
@@ -690,7 +690,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
690690
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
691691
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
692692
func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
693-
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
693+
%0 = tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
694694
%1 = sparse_tensor.load %0 hasInserts : tensor<5x6xf64, #Coo>
695695
return %1 : tensor<5x6xf64, #Coo>
696696
}

mlir/test/Dialect/SparseTensor/constant_index_map.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
2121
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
2222
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1
23-
// CHECK: %[[VAL_14:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
23+
// CHECK: %[[VAL_14:.*]] = tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
2424
// CHECK: scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}>
2525
// CHECK: }
2626
// CHECK: %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}>

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
318318
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
319319
%arg1: index,
320320
%arg2: f32) -> tensor<128xf32, #SparseVector> {
321-
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
321+
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
322322
return %0 : tensor<128xf32, #SparseVector>
323323
}
324324

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -290,24 +290,6 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64
290290

291291
// -----
292292

293-
func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: index, %arg2: f64) {
294-
// expected-error@+1 {{'sparse_tensor.insert' 'tensor' must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
295-
sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64>
296-
return
297-
}
298-
299-
// -----
300-
301-
#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>
302-
303-
func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: index, %arg2: f64) {
304-
// expected-error@+1 {{'sparse_tensor.insert' op incorrect number of coordinates}}
305-
sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128x64xf64, #CSR>
306-
return
307-
}
308-
309-
// -----
310-
311293
func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f32) -> (memref<?xf64>, index) {
312294
// expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
313295
%0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f32

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,10 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
311311
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse{{[0-9]*}}>,
312312
// CHECK-SAME: %[[B:.*]]: index,
313313
// CHECK-SAME: %[[C:.*]]: f64)
314-
// CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
314+
// CHECK: %[[T:.*]] = tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
315315
// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
316316
func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
317-
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
317+
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
318318
return %0 : tensor<128xf64, #SparseVector>
319319
}
320320

0 commit comments

Comments
 (0)