Commit 951a363

Author: Peiming Liu

[mlir][sparse] implement sparse_tensor.extract_value operation. (#101220)

1 parent 2aa96fc commit 951a363

4 files changed: +40 -7 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 20 additions & 4 deletions
@@ -2,6 +2,7 @@
 #include "Utils/CodegenUtils.h"
 #include "Utils/SparseTensorIterator.h"
 
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -10,8 +11,8 @@
 using namespace mlir;
 using namespace mlir::sparse_tensor;
 
-void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
-                      SmallVectorImpl<Type> &fields) {
+static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
+                             SmallVectorImpl<Type> &fields) {
   // Position and coordinate buffer in the sparse structure.
   if (enc.getLvlType(lvl).isWithPosLT())
     fields.push_back(enc.getPosMemRefType());
@@ -71,6 +72,21 @@ class ExtractIterSpaceConverter
   }
 };
 
+/// Sparse codegen rule for the extract value operator.
+class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
+public:
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+  LogicalResult
+  matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value pos = adaptor.getIterator().back();
+    Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
+    rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
+    return success();
+  }
+};
+
 class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
 public:
   using OneToNOpConversionPattern::OneToNOpConversionPattern;
@@ -193,6 +209,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
 
   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
-  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
-      converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
+               SparseIterateOpConverter>(converter, patterns.getContext());
 }
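In short, the new ExtractValOpConverter rewrites sparse_tensor.extract_value into a plain memref.load from the tensor's values buffer, indexed by the last value of the converted iterator. A minimal before/after sketch (hand-written for illustration; #COO is a placeholder encoding and the iterator type annotation is elided, so the exact printed form of the op may differ):

  // Before lowering: read the value at the iterator's current position.
  %v = sparse_tensor.extract_value %t at %it : tensor<?xi32, #COO>

  // After --lower-sparse-iteration-to-scf: materialize the values buffer
  // (ToValuesOp) and load at the position carried by the iterator.
  %vals = sparse_tensor.values %t : tensor<?xi32, #COO> to memref<?xi32>
  %v = memref.load %vals[%pos] : memref<?xi32>

Here %pos stands for adaptor.getIterator().back(), i.e. the position component of the lowered iterator, exactly as in the pattern above.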

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 10 additions & 0 deletions
@@ -357,6 +357,9 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
     const auto pos = env.emitter().getValPosits(tid);
     assert(!pos.empty());
     args.append(pos);
+    // Simply return the tensor; its value will be extracted via iterators.
+    if (env.options().sparseEmitStrategy == SparseEmitStrategy::kSparseIterator)
+      return t->get();
   } else {
     // For dense tensors we push all level's coordinates onto `args`.
     const Level lvlRank = stt.getLvlRank();
@@ -512,9 +515,16 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
       return genInsertionLoadReduce(env, builder, t);
     return genInsertionLoad(env, builder, t);
   }
+
   // Actual load.
   SmallVector<Value> args;
   Value ptr = genSubscript(env, builder, t, args);
+  if (llvm::isa<TensorType>(ptr.getType())) {
+    assert(env.options().sparseEmitStrategy ==
+               SparseEmitStrategy::kSparseIterator &&
+           args.size() == 1);
+    return builder.create<ExtractValOp>(loc, ptr, args.front());
+  }
   return builder.create<memref::LoadOp>(loc, ptr, args);
 }
 
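Taken together with the LoopEmitter change below, this is the core of the commit: under sparse-emit-strategy=sparse-iterator, genSubscript returns the sparse tensor itself with the loop iterator as the single "subscript", and genTensorLoad detects the TensorType result and emits sparse_tensor.extract_value instead of a direct memref.load. A rough sketch of the IR this produces inside a sparse_tensor.iterate loop (iterator and iteration-space types elided for readability, so this is illustrative rather than verbatim):

  // %it is the sparse_tensor.iterate iterator over the stored entries of %t.
  // The value read becomes an extract_value op, which the lowering pass in
  // the previous file later turns into a buffer load.
  %v = sparse_tensor.extract_value %t at %it : tensor<?x?x?x?xi32, #COO>
  %m = arith.muli %v, %v : i32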

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 5 additions & 0 deletions
@@ -221,6 +221,11 @@ class LoopEmitter {
   /// Getters.
   ///
   SmallVector<Value> getValPosits(TensorId tid) const {
+    // Returns the iterator if we are generating sparse (co)iterate-based loops.
+    if (emitStrategy == SparseEmitStrategy::kSparseIterator)
+      return {spIterVals[tid].back()};
+
+    // Returns {[batch coords], last-level position}.
     SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
     Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
     batchCrds.push_back(lastLvlPos);

mlir/test/Dialect/SparseTensor/sparse_kernels_to_iterator.mlir

Lines changed: 5 additions & 3 deletions
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s
+// RUN: mlir-opt %s --sparse-reinterpret-map -sparsification="sparse-emit-strategy=sparse-iterator" --cse --sparse-space-collapse --lower-sparse-iteration-to-scf --loop-invariant-code-motion | FileCheck %s
 
 
 #COO = #sparse_tensor.encoding<{
@@ -7,8 +7,7 @@
     d1 : singleton(nonunique, soa),
     d2 : singleton(nonunique, soa),
     d3 : singleton(soa)
-  ),
-  explicitVal = 1 : i32
+  )
 }>
 
 // CHECK-LABEL: func.func @sqsum(
@@ -17,7 +16,10 @@
 // CHECK-DAG: %[[POS_BUF:.*]] = sparse_tensor.positions %{{.*}} {level = 0 : index} : tensor<?x?x?x?xi32, #sparse> to memref<?xindex>
 // CHECK: %[[POS_LO:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C0]]] : memref<?xindex>
 // CHECK: %[[POS_HI:.*]] = memref.load %[[POS_BUF]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[VAL_BUF:.*]] = sparse_tensor.values %{{.*}} : tensor<?x?x?x?xi32, #sparse> to memref<?xi32>
 // CHECK: %[[SQ_SUM:.*]] = scf.for %[[POS:.*]] = %[[POS_LO]] to %[[POS_HI]] step %[[C1]] {{.*}} {
+// CHECK: %[[VAL:.*]] = memref.load %[[VAL_BUF]]{{\[}}%[[POS]]] : memref<?xi32>
+// CHECK: %[[MUL:.*]] = arith.muli %[[VAL]], %[[VAL]] : i32
 // CHECK: %[[SUM:.*]] = arith.addi
 // CHECK: scf.yield %[[SUM]] : i32
 // CHECK: }
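Two things changed in this test: the encoding no longer declares explicitVal = 1 : i32, so the kernel must actually load each stored value (which is what the new VAL_BUF/VAL checks verify), and the RUN line adds --cse and --loop-invariant-code-motion so the values buffer is hoisted out of the loop. Pieced together from the CHECK lines, the lowered sqsum loop looks roughly like this (SSA names and the iter_args wiring are reconstructed for illustration):

  %pos_lo = memref.load %pos_buf[%c0] : memref<?xindex>
  %pos_hi = memref.load %pos_buf[%c1] : memref<?xindex>
  %val_buf = sparse_tensor.values %t : tensor<?x?x?x?xi32, #sparse> to memref<?xi32>
  %sq_sum = scf.for %pos = %pos_lo to %pos_hi step %c1 iter_args(%acc = %i0) -> (i32) {
    %val = memref.load %val_buf[%pos] : memref<?xi32>
    %mul = arith.muli %val, %val : i32
    %sum = arith.addi %acc, %mul : i32
    scf.yield %sum : i32
  }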
