Skip to content

[mlir][sparse] reuse tensor.insert operation to insert elements into … #84987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 0 additions & 42 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -668,48 +668,6 @@ def SparseTensor_CrdTranslateOp : SparseTensor_Op<"crd_translate", [Pure]>,
// refined over time as our sparse abstractions evolve.
//===----------------------------------------------------------------------===//

def SparseTensor_InsertOp : SparseTensor_Op<"insert",
[TypesMatchWith<"value type matches element type of tensor",
"tensor", "value",
"::llvm::cast<ShapedType>($_self).getElementType()">,
AllTypesMatch<["tensor", "result"]>]>,
Arguments<(ins AnyType:$value,
AnySparseTensor:$tensor,
Variadic<Index>:$lvlCoords)>,
Results<(outs AnySparseTensor:$result)> {
string summary = "Inserts a value into the sparse tensor";
string description = [{
Inserts the value into the underlying storage of the tensor at the
given level-coordinates. The arity of `lvlCoords` must match the
level-rank of the tensor. This operation can only be applied when
the tensor materializes uninitialized from a `tensor.empty` operation
and the final tensor is constructed with a `load` operation which
has the `hasInserts` attribute set.

The level-properties of the sparse tensor type fully describe what
kind of insertion order is allowed. When all levels have "unique"
and "ordered" properties, for example, insertions should occur in
strict lexicographical level-coordinate order. Other properties
define different insertion regimens. Inserting in a way contrary
to these properties results in undefined behavior.

Note that this operation is "impure" in the sense that even though
the result is modeled through an SSA value, the insertion is eventually
done "in place", and referencing the old SSA value is undefined behavior.
This operation is scheduled to be unified with the dense counterpart
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the two reasons why we initially had a completely separate op were (1) order of insertions and (2) use of the old SSA value, but not honoring such semantics seems more a lowering problem than something that should be called out in the op anymore.

`tensor.insert` that has pure SSA semantics.

Example:

```mlir
%result = sparse_tensor.insert %val into %tensor[%i,%j] : tensor<1024x1024xf64, #CSR>
```
}];
let assemblyFormat = "$value `into` $tensor `[` $lvlCoords `]` attr-dict"
"`:` type($tensor)";
let hasVerifier = 1;
}

def SparseTensor_PushBackOp : SparseTensor_Op<"push_back",
[TypesMatchWith<"value type matches element type of inBuffer",
"inBuffer", "value",
Expand Down
9 changes: 0 additions & 9 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1741,15 +1741,6 @@ LogicalResult ConcatenateOp::verify() {
return success();
}

/// Verifier for the sparse insert op. Rejects tensors whose encoding reports
/// a non-zero batch level-rank, then requires the number of supplied level
/// coordinates to equal the level-rank of the tensor type.
LogicalResult InsertOp::verify() {
  const auto tensorType = getSparseTensorType(getTensor());

  // Insertion into batched sparse tensors is reported as unimplemented.
  const bool isBatched = tensorType.getEncoding().getBatchLvlRank() > 0;
  if (isBatched)
    return emitOpError("batched sparse tensor insertion not implemented");

  // One coordinate per storage level is expected.
  const auto numCoords = static_cast<Level>(getLvlCoords().size());
  if (numCoords != tensorType.getLvlRank())
    return emitOpError("incorrect number of coordinates");

  return success();
}

void PushBackOp::build(OpBuilder &builder, OperationState &result,
Value curSize, Value inBuffer, Value value) {
build(builder, result, curSize, inBuffer, value, Value());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,27 +187,6 @@ struct DisassembleOpInterface
}
};

/// External bufferization model for sparse_tensor::InsertOp. All three hooks
/// answer unconditionally: they do not inspect the queried operand or the
/// analysis state.
struct InsertOpInterface : public SparseBufferizableOpInterfaceExternalModel<
InsertOpInterface, sparse_tensor::InsertOp> {
/// Reports that the op reads memory, for whichever operand is queried.
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}

/// Reports that the op writes memory, for whichever operand is queried.
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// InsertOp writes to memory.
return true;
}

/// Returns the op's single result, tagged with BufferRelation::Equivalent,
/// as the aliasing value of the queried operand.
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
// InsertOp returns an alias of its operand.
// The op is expected to have exactly one result; result 0 is used below.
assert(op->getNumResults() == 1);
return {{op->getOpResult(0), BufferRelation::Equivalent}};
}
};

struct NumberOfEntriesOpInterface
: public SparseBufferizableOpInterfaceExternalModel<
NumberOfEntriesOpInterface, sparse_tensor::NumberOfEntriesOp> {
Expand Down Expand Up @@ -324,7 +303,6 @@ void mlir::sparse_tensor::registerBufferizableOpInterfaceExternalModels(
sparse_tensor::ConvertOp::attachInterface<ConvertOpInterface>(*ctx);
sparse_tensor::LoadOp::attachInterface<LoadOpInterface>(*ctx);
sparse_tensor::NewOp::attachInterface<NewOpInterface>(*ctx);
sparse_tensor::InsertOp::attachInterface<InsertOpInterface>(*ctx);
sparse_tensor::NumberOfEntriesOp::attachInterface<
NumberOfEntriesOpInterface>(*ctx);
sparse_tensor::AssembleOp::attachInterface<AssembleOpInterface>(*ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,14 +640,14 @@ struct TensorInsertDemapper
using DemapInsRewriter::DemapInsRewriter;
LogicalResult rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
PatternRewriter &rewriter) const {
if (!hasAnySparseResult(op))
if (!hasAnySparseResult(op) || !hasAnyNonIdentityOperandsOrResults(op))
return failure();

Location loc = op.getLoc();
auto stt = getSparseTensorType(op.getResult());
ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
CrdTransDirectionKind::dim2lvl);
auto insertOp = rewriter.create<sparse_tensor::InsertOp>(
auto insertOp = rewriter.create<tensor::InsertOp>(
loc, op.getScalar(), adaptor.getDest(), lvlCrd);

Value out = genRemap(rewriter, stt.getEncoding(), insertOp.getResult());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1014,24 +1014,29 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
};

/// Sparse codegen rule for the insert operator.
class SparseInsertConverter : public OpConversionPattern<InsertOp> {
class SparseInsertConverter : public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto stt = getSparseTensorType(adaptor.getDest());
if (!stt.hasEncoding())
return failure();
assert(stt.isIdentity() && "Run reinterpret-map before conversion.");

Location loc = op.getLoc();
auto desc = getDescriptorFromTensorTuple(adaptor.getTensor());
auto desc = getDescriptorFromTensorTuple(adaptor.getDest());
TypeRange flatSpTensorTps = desc.getFields().getTypes();
SmallVector<Value> params = llvm::to_vector(desc.getFields());
params.append(adaptor.getLvlCoords().begin(), adaptor.getLvlCoords().end());
params.push_back(adaptor.getValue());
SparseInsertGenerator insertGen(op.getTensor().getType(), flatSpTensorTps,
params.append(adaptor.getIndices().begin(), adaptor.getIndices().end());
params.push_back(adaptor.getScalar());
SparseInsertGenerator insertGen(op.getDest().getType(), flatSpTensorTps,
params, /*genCall=*/true);
SmallVector<Value> ret = insertGen.genCallOrInline(rewriter, loc);
// Replace operation with resulting memrefs.
rewriter.replaceOp(op,
genTuple(rewriter, loc, op.getTensor().getType(), ret));
genTuple(rewriter, loc, op.getDest().getType(), ret));
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -580,17 +580,24 @@ class SparseTensorLoadConverter : public OpConversionPattern<LoadOp> {
};

/// Sparse conversion rule for the insertion operator.
class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
class SparseTensorInsertConverter
: public OpConversionPattern<tensor::InsertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(InsertOp op, OpAdaptor adaptor,
matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Note that the current regime only allows for strict lexicographic
// coordinate order. All values are passed by reference through stack
// allocated memrefs.
Location loc = op->getLoc();
const auto stt = getSparseTensorType(op.getTensor());
const auto stt = getSparseTensorType(op.getDest());

// Dense tensor insertion.
if (!stt.hasEncoding())
return failure();

assert(stt.isIdentity() && "Run reinterpret-map before conversion.");
const auto elemTp = stt.getElementType();
const Level lvlRank = stt.getLvlRank();
Value lvlCoords, vref;
Expand All @@ -608,12 +615,12 @@ class SparseTensorInsertConverter : public OpConversionPattern<InsertOp> {
lvlCoords = genAlloca(rewriter, loc, lvlRank, rewriter.getIndexType());
vref = genAllocaScalar(rewriter, loc, elemTp);
}
storeAll(rewriter, loc, lvlCoords, adaptor.getLvlCoords());
rewriter.create<memref::StoreOp>(loc, adaptor.getValue(), vref);
storeAll(rewriter, loc, lvlCoords, adaptor.getIndices());
rewriter.create<memref::StoreOp>(loc, adaptor.getScalar(), vref);
SmallString<12> name{"lexInsert", primaryTypeFunctionSuffix(elemTp)};
createFuncCall(rewriter, loc, name, {},
{adaptor.getTensor(), lvlCoords, vref}, EmitCInterface::On);
rewriter.replaceOp(op, adaptor.getTensor());
{adaptor.getDest(), lvlCoords, vref}, EmitCInterface::On);
rewriter.replaceOp(op, adaptor.getDest());
return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,8 @@ struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
reshapeCvs(builder, loc, expandReass, collapsedSizes, collapsedDcvs,
dstSizes, dstDcvs);

auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
auto t =
builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});

Expand Down Expand Up @@ -901,7 +902,8 @@ struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
SmallVector<Value> dstDcvs;
reshapeCvs(builder, loc, op.getReassociationIndices(), srcSizes,
srcDcvs, dstSizes, dstDcvs);
auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
auto t =
builder.create<tensor::InsertOp>(loc, v, reduc.front(), dstDcvs);
builder.create<sparse_tensor::YieldOp>(loc, t);
});

Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
/*else=*/true);
// True branch.
builder.setInsertionPointToStart(ifValidLexInsert.thenBlock());
Value res = builder.create<InsertOp>(loc, rhs, chain, ivs);
Value res = builder.create<tensor::InsertOp>(loc, rhs, chain, ivs);
builder.create<scf::YieldOp>(loc, res);
// False branch.
builder.setInsertionPointToStart(ifValidLexInsert.elseBlock());
Expand All @@ -438,7 +438,8 @@ static void genInsertionStore(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
env.updateInsertionChain(ifValidLexInsert.getResult(0));
} else {
// Generates regular insertion chain.
env.updateInsertionChain(builder.create<InsertOp>(loc, rhs, chain, ivs));
env.updateInsertionChain(
builder.create<tensor::InsertOp>(loc, rhs, chain, ivs));
}
return;
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/SparseTensor/codegen.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -643,7 +643,7 @@ func.func @sparse_compression_unordered(%tensor: tensor<8x8xf64, #UCSR>,
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_0_0(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SV> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SV>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SV>
return %1 : tensor<128xf64, #SV>
}
Expand All @@ -666,7 +666,7 @@ func.func @sparse_insert(%arg0: tensor<128xf64, #SV>, %arg1: index, %arg2: f64)
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_128_f64_64_32(%[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A5]], %[[A6]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
%1 = sparse_tensor.load %0 hasInserts : tensor<128xf64, #SparseVector>
return %1 : tensor<128xf64, #SparseVector>
}
Expand All @@ -690,7 +690,7 @@ func.func @sparse_insert_typed(%arg0: tensor<128xf64, #SparseVector>, %arg1: ind
// CHECK: %[[R:.*]]:4 = call @_insert_compressed_nonunique_singleton_5_6_f64_0_0(%[[A0]], %[[A1]], %[[A2]], %[[A3]], %[[A4]], %[[A4]], %[[A5]])
// CHECK: return %[[R]]#0, %[[R]]#1, %[[R]]#2, %[[R]]#3
func.func @sparse_insert_coo(%arg0: tensor<5x6xf64, #Coo>, %arg1: index, %arg2: f64) -> tensor<5x6xf64, #Coo> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
%0 = tensor.insert %arg2 into %arg0[%arg1, %arg1] : tensor<5x6xf64, #Coo>
%1 = sparse_tensor.load %0 hasInserts : tensor<5x6xf64, #Coo>
return %1 : tensor<5x6xf64, #Coo>
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/constant_index_map.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// CHECK: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_7]]{{\[}}%[[VAL_3]], %[[VAL_9]]] : memref<1x77xi1>
// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i1
// CHECK: %[[VAL_14:.*]] = sparse_tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
// CHECK: %[[VAL_14:.*]] = tensor.insert %[[VAL_13]] into %[[VAL_10]]{{\[}}%[[VAL_9]]] : tensor<77xi1, #{{.*}}>
// CHECK: scf.yield %[[VAL_14]] : tensor<77xi1, #{{.*}}>
// CHECK: }
// CHECK: %[[VAL_15:.*]] = sparse_tensor.load %[[VAL_16:.*]] hasInserts : tensor<77xi1, #{{.*}}>
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ func.func @sparse_reconstruct_ins(%arg0: tensor<128xf32, #SparseVector>) -> tens
func.func @sparse_insert(%arg0: tensor<128xf32, #SparseVector>,
%arg1: index,
%arg2: f32) -> tensor<128xf32, #SparseVector> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf32, #SparseVector>
return %0 : tensor<128xf32, #SparseVector>
}

Expand Down
18 changes: 0 additions & 18 deletions mlir/test/Dialect/SparseTensor/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -290,24 +290,6 @@ func.func @sparse_unannotated_load(%arg0: tensor<16x32xf64>) -> tensor<16x32xf64

// -----

func.func @sparse_unannotated_insert(%arg0: tensor<128xf64>, %arg1: index, %arg2: f64) {
// expected-error@+1 {{'sparse_tensor.insert' 'tensor' must be sparse tensor of any type values, but got 'tensor<128xf64>'}}
sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64>
return
}

// -----

#CSR = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d1 : compressed)}>

func.func @sparse_wrong_arity_insert(%arg0: tensor<128x64xf64, #CSR>, %arg1: index, %arg2: f64) {
// expected-error@+1 {{'sparse_tensor.insert' op incorrect number of coordinates}}
sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128x64xf64, #CSR>
return
}

// -----

func.func @sparse_push_back(%arg0: index, %arg1: memref<?xf64>, %arg2: f32) -> (memref<?xf64>, index) {
// expected-error@+1 {{'sparse_tensor.push_back' op failed to verify that value type matches element type of inBuffer}}
%0:2 = sparse_tensor.push_back %arg0, %arg1, %arg2 : index, memref<?xf64>, f32
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,10 @@ func.func @sparse_load_ins(%arg0: tensor<16x32xf64, #DenseMatrix>) -> tensor<16x
// CHECK-SAME: %[[A:.*]]: tensor<128xf64, #sparse{{[0-9]*}}>,
// CHECK-SAME: %[[B:.*]]: index,
// CHECK-SAME: %[[C:.*]]: f64)
// CHECK: %[[T:.*]] = sparse_tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
// CHECK: %[[T:.*]] = tensor.insert %[[C]] into %[[A]][%[[B]]] : tensor<128xf64, #{{.*}}>
// CHECK: return %[[T]] : tensor<128xf64, #{{.*}}>
func.func @sparse_insert(%arg0: tensor<128xf64, #SparseVector>, %arg1: index, %arg2: f64) -> tensor<128xf64, #SparseVector> {
%0 = sparse_tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
%0 = tensor.insert %arg2 into %arg0[%arg1] : tensor<128xf64, #SparseVector>
return %0 : tensor<128xf64, #SparseVector>
}

Expand Down
Loading