
Commit 6b51c7a (parent: 1539989)
Author: Peiming Liu

Revert "[mlir][sparse] implement lowering rules for IterateOp. (#95286)"

This reverts commit 3a2e442.
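
For context, the reverted change lowered sparse_tensor.iterate loops into plain scf.for / scf.while loops. Below is a minimal sketch of the kind of input IR that lowering handled, adapted from the test case removed in this commit; the function name, the counting body, and the exact level types in the #COO encoding are illustrative assumptions, not part of the patch itself.

#COO = #sparse_tensor.encoding<{
  map = (i, j) -> (i : compressed(nonunique), j : singleton)
}>

// Count the stored entries in level 0 of a COO tensor by iterating over its
// level-0 iteration space; the reverted pattern rewrote such loops into scf ops.
func.func @count_lvl0(%sp : tensor<4x8xf32, #COO>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %space = sparse_tensor.extract_iteration_space %sp lvls = 0
      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
  %sum = sparse_tensor.iterate %it in %space iter_args(%acc = %c0)
      : !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
    %next = arith.addi %acc, %c1 : index
    sparse_tensor.yield %next : index
  }
  return %sum : index
}

After this revert, only ExtractIterSpaceConverter remains registered, so the --lower-sparse-iteration-to-scf pass no longer rewrites the iterate loop itself.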

File tree: 4 files changed, 17 additions and 224 deletions

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 1 addition & 120 deletions
@@ -34,20 +34,6 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
   return success();
 }
 
-static std::optional<LogicalResult>
-convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
-  // The actually Iterator Values (that are updated every iteration).
-  auto idxTp = IndexType::get(itTp.getContext());
-  // TODO: handle batch dimension.
-  assert(itTp.getEncoding().getBatchLvlRank() == 0);
-  if (!itTp.isUnique()) {
-    // Segment high for non-unique iterator.
-    fields.push_back(idxTp);
-  }
-  fields.push_back(idxTp);
-  return success();
-}
-
 namespace {
 
 /// Sparse codegen rule for number of entries operator.
@@ -71,114 +57,10 @@ class ExtractIterSpaceConverter
   }
 };
 
-class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
-public:
-  using OneToNOpConversionPattern::OneToNOpConversionPattern;
-  LogicalResult
-  matchAndRewrite(IterateOp op, OpAdaptor adaptor,
-                  OneToNPatternRewriter &rewriter) const override {
-    if (!op.getCrdUsedLvls().empty())
-      return rewriter.notifyMatchFailure(
-          op, "non-empty coordinates list not implemented.");
-
-    Location loc = op.getLoc();
-
-    auto iterSpace = SparseIterationSpace::fromValues(
-        op.getIterSpace().getType(), adaptor.getIterSpace(), 0);
-
-    std::unique_ptr<SparseIterator> it =
-        iterSpace.extractIterator(rewriter, loc);
-
-    if (it->iteratableByFor()) {
-      auto [lo, hi] = it->genForCond(rewriter, loc);
-      Value step = constantIndex(rewriter, loc, 1);
-      SmallVector<Value> ivs;
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-      scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
-
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      forOp.getBody()->erase();
-      Region &dstRegion = forOp.getRegion();
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-      auto yieldOp =
-          llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(forOp.getBody());
-      // replace sparse_tensor.yield with scf.yield.
-      rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
-      yieldOp.erase();
-
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-    } else {
-      SmallVector<Value> ivs;
-      llvm::append_range(ivs, it->getCursor());
-      for (ValueRange inits : adaptor.getInitArgs())
-        llvm::append_range(ivs, inits);
-
-      assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-      TypeRange types = ValueRange(ivs).getTypes();
-      auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
-      SmallVector<Location> l(types.size(), op.getIterator().getLoc());
-
-      // Generates loop conditions.
-      Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-      rewriter.setInsertionPointToStart(before);
-      ValueRange bArgs = before->getArguments();
-      auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-      assert(remArgs.size() == adaptor.getInitArgs().size());
-      rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
-      // Generates loop body.
-      Block *loopBody = op.getBody();
-      OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-      if (failed(typeConverter->convertSignatureArgs(
-              loopBody->getArgumentTypes(), bodyTypeMapping)))
-        return failure();
-      rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-      Region &dstRegion = whileOp.getAfter();
-      // TODO: handle uses of coordinate!
-      rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-      ValueRange aArgs = whileOp.getAfterArguments();
-      auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
-          whileOp.getAfterBody()->getTerminator());
-
-      rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
-
-      aArgs = it->linkNewScope(aArgs);
-      ValueRange nx = it->forward(rewriter, loc);
-      SmallVector<Value> yields;
-      llvm::append_range(yields, nx);
-      llvm::append_range(yields, yieldOp.getResults());
-
-      // replace sparse_tensor.yield with scf.yield.
-      yieldOp->erase();
-      rewriter.create<scf::YieldOp>(loc, yields);
-
-      const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-      rewriter.replaceOp(
-          op, whileOp.getResults().drop_front(it->getCursor().size()),
-          resultMapping);
-    }
-    return success();
-  }
-};
-
 } // namespace
 
 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
   addConversion([](Type type) { return type; });
-  addConversion(convertIteratorType);
   addConversion(convertIterSpaceType);
 
   addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
@@ -192,6 +74,5 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
 
 void mlir::populateLowerSparseIterationToSCFPatterns(
     TypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<ExtractIterSpaceConverter, SparseIterateOpConverter>(
-      converter, patterns.getContext());
+  patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
 }

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.cpp

Lines changed: 0 additions & 40 deletions
@@ -331,13 +331,6 @@ class TrivialIterator : public ConcreteIterator {
   TrivialIterator(const SparseTensorLevel &stl)
       : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1) {}
 
-  TrivialIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
-                  Value posLo, Value posHi)
-      : ConcreteIterator(stl, IterKind::kTrivial, /*itValCnt=*/1), posLo(posLo),
-        posHi(posHi) {
-    seek(posLo);
-  }
-
   std::string getDebugInterfacePrefix() const override {
     return std::string("trivial<") + stl.toString() + ">";
   }
@@ -427,14 +420,6 @@ class DedupIterator : public ConcreteIterator {
       : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2) {
     assert(!stl.isUnique());
   }
-
-  DedupIterator(OpBuilder &b, Location l, const SparseTensorLevel &stl,
-                Value posLo, Value posHi)
-      : ConcreteIterator(stl, IterKind::kDedup, /*itValCnt=*/2), posHi(posHi) {
-    assert(!stl.isUnique());
-    seek({posLo, genSegmentHigh(b, l, posLo)});
-  }
-
   // For LLVM-style RTTI.
   static bool classof(const SparseIterator *from) {
     return from->kind == IterKind::kDedup;
@@ -1547,11 +1532,6 @@ SparseIterationSpace mlir::sparse_tensor::SparseIterationSpace::fromValues(
   return space;
 }
 
-std::unique_ptr<SparseIterator>
-SparseIterationSpace::extractIterator(OpBuilder &b, Location l) const {
-  return makeSimpleIterator(b, l, *this);
-}
-
 //===----------------------------------------------------------------------===//
 // SparseIterator factory functions.
 //===----------------------------------------------------------------------===//
@@ -1610,26 +1590,6 @@ sparse_tensor::makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl,
   return std::make_pair(std::move(stl), std::move(it));
 }
 
-std::unique_ptr<SparseIterator>
-sparse_tensor::makeSimpleIterator(OpBuilder &b, Location l,
-                                  const SparseIterationSpace &iterSpace) {
-  // assert(iterSpace.getSpaceDim() == 1);
-  std::unique_ptr<SparseIterator> ret;
-  if (!iterSpace.isUnique()) {
-    // We always dedupliate the non-unique level, but we should optimize it away
-    // if possible.
-    ret = std::make_unique<DedupIterator>(b, l, iterSpace.getLastLvl(),
-                                          iterSpace.getBoundLo(),
-                                          iterSpace.getBoundHi());
-  } else {
-    ret = std::make_unique<TrivialIterator>(b, l, iterSpace.getLastLvl(),
-                                            iterSpace.getBoundLo(),
-                                            iterSpace.getBoundHi());
-  }
-  ret->setSparseEmitStrategy(SparseEmitStrategy::kFunctional);
-  return ret;
-}
-
 std::unique_ptr<SparseIterator>
 sparse_tensor::makeSimpleIterator(const SparseTensorLevel &stl,
                                   SparseEmitStrategy strategy) {

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorIterator.h

Lines changed: 3 additions & 23 deletions
@@ -132,10 +132,6 @@ class SparseIterationSpace {
   Value getBoundLo() const { return bound.first; }
   Value getBoundHi() const { return bound.second; }
 
-  // Extract an iterator to iterate over the sparse iteration space.
-  std::unique_ptr<SparseIterator> extractIterator(OpBuilder &b,
-                                                  Location l) const;
-
 private:
   SmallVector<std::unique_ptr<SparseTensorLevel>> lvls;
   std::pair<Value, Value> bound;
@@ -196,13 +192,6 @@ class SparseIterator {
     crd = nullptr;
   }
 
-  // Reconstructs a iteration space directly from the provided ValueRange.
-  static std::unique_ptr<SparseIterator>
-  fromValues(IteratorType dstTp, ValueRange values, unsigned tid);
-
-  // The inverse operation of `fromValues`.
-  SmallVector<Value> toValues() const { llvm_unreachable("Not implemented"); }
-
   //
   // Iterator properties.
   //
@@ -356,21 +345,12 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &b,
                                                          unsigned tid,
                                                          Level lvl);
 
-/// Helper function to create a TensorLevel object from given ValueRange.
+/// Helper function to create a TensorLevel object from given `tensor`.
 std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(LevelType lt, Value sz,
                                                          ValueRange buffers,
                                                          unsigned tid, Level l);
-
-/// Helper function to create a simple SparseIterator object that iterate
-/// over the entire iteration space.
-std::unique_ptr<SparseIterator>
-makeSimpleIterator(OpBuilder &b, Location l,
-                   const SparseIterationSpace &iterSpace);
-
-/// Helper function to create a simple SparseIterator object that iterate
-/// over the sparse tensor level.
-/// TODO: switch to `SparseIterationSpace` (which support N-D iterator) when
-/// feature complete.
+/// Helper function to create a simple SparseIterator object that iterates
+/// over the SparseTensorLevel.
 std::unique_ptr<SparseIterator> makeSimpleIterator(
     const SparseTensorLevel &stl,
     SparseEmitStrategy strategy = SparseEmitStrategy::kFunctional);
Lines changed: 13 additions & 41 deletions
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s --lower-sparse-iteration-to-scf | FileCheck %s
-// RUN: mlir-opt %s --sparse-space-collapse --lower-sparse-iteration-to-scf | FileCheck %s --check-prefix COLLAPSED
 
 #COO = #sparse_tensor.encoding<{
   map = (i, j) -> (
@@ -8,44 +7,17 @@
   )
 }>
 
-// CHECK-LABEL: @sparse_iteration_to_scf
-// // deduplication
-// CHECK: scf.while {{.*}} {
-// CHECK: } do {
-// CHECK: }
-// CHECK: scf.while {{.*}} {
-// CHECK: } do {
-// // actual computation
-// CHECK: scf.for {{.*}} {
-// CHECK: arith.addi
-// CHECK: }
-// // deduplication
-// CHECK: scf.while {{.*}} {
-// CHECK: } do {
-// CHECK: }
-// CHECK: scf.yield
-// CHECK: }
-// CHECK: return
-
-// COLLAPSED-LABEL: @sparse_iteration_to_scf
-// COLLAPSED: %[[RET:.*]] = scf.for {{.*}} {
-// COLLAPSED: %[[VAL:.*]] = arith.addi
-// COLLAPSED: scf.yield %[[VAL]] : index
-// COLLAPSED: }
-// COLLAPSED: return %[[RET]] : index
-func.func @sparse_iteration_to_scf(%sp : tensor<4x8xf32, #COO>) -> index {
-  %i = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
-      : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
-  %r1 = sparse_tensor.iterate %it1 in %l1 iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
-    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
-        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1> -> !sparse_tensor.iter_space<#COO, lvls = 1>
-    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
-      %k = arith.addi %inner, %c1 : index
-      sparse_tensor.yield %k : index
-    }
-    sparse_tensor.yield %r2 : index
-  }
-  return %r1 : index
+// CHECK-LABEL: func.func @sparse_1D_space(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32, #sparse{{[0-9]*}}>) -> !sparse_tensor.iter_space<#sparse{{[0-9]*}}, lvls = 0> {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[LVL_SIZE:.*]] = sparse_tensor.lvl %[[VAL_0]], %[[C0]] : tensor<?x?xf32, #sparse{{[0-9]*}}>
+// CHECK: %[[POS_MEM:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[CRD_MEM:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<?x?xf32, #sparse{{[0-9]*}}> to memref<?xindex>
+// CHECK: %[[POS_LO:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C0]]] : memref<?xindex>
+// CHECK: %[[POS_HI:.*]] = memref.load %[[POS_MEM]]{{\[}}%[[C1]]] : memref<?xindex>
+// CHECK: %[[ITER_SPACE:.*]] = builtin.unrealized_conversion_cast %[[POS_MEM]], %[[CRD_MEM]], %[[LVL_SIZE]], %[[POS_LO]], %[[POS_HI]]
+func.func @sparse_1D_space(%sp : tensor<?x?xf32, #COO>) -> !sparse_tensor.iter_space<#COO, lvls = 0> {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  return %l1 : !sparse_tensor.iter_space<#COO, lvls = 0>
 }
