Skip to content

Commit c5c7bb1

Browse files
author
Peiming Liu
committed
[mlir][sparse] implement lowering rules for ExtractIterSpaceOp.
1 parent 6adcdb6 commit c5c7bb1

File tree

10 files changed

+356
-65
lines changed

10 files changed

+356
-65
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ struct LevelType {
357357
return hasSparseSemantic();
358358
}
359359

360+
constexpr unsigned getNumBuffer() const {
361+
return hasDenseSemantic() ? 0 : (isWithPosLT() ? 2 : 1);
362+
}
363+
360364
std::string toMLIRString() const {
361365
std::string lvlStr = toFormatString(getLvlFmt());
362366
std::string propStr = "";

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/IR/PatternMatch.h"
1717
#include "mlir/Pass/Pass.h"
1818
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/OneToNTypeConversion.h"
1920

2021
//===----------------------------------------------------------------------===//
2122
// Include the generated pass header (which needs some early definitions).
@@ -143,6 +144,20 @@ void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
143144

144145
std::unique_ptr<Pass> createLowerForeachToSCFPass();
145146

147+
//===----------------------------------------------------------------------===//
148+
// The LowerSparseIterationToSCF pass.
149+
//===----------------------------------------------------------------------===//
150+
151+
/// Type converter for iter_space and iterator.
152+
struct SparseIterationTypeConverter : public OneToNTypeConverter {
153+
SparseIterationTypeConverter();
154+
};
155+
156+
void populateLowerSparseIterationToSCFPatterns(TypeConverter &converter,
157+
RewritePatternSet &patterns);
158+
159+
std::unique_ptr<Pass> createLowerSparseIterationToSCFPass();
160+
146161
//===----------------------------------------------------------------------===//
147162
// The SparseTensorConversion pass.
148163
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,4 +479,19 @@ def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
479479
];
480480
}
481481

482+
def LowerSparseIterationToSCF : Pass<"lower-sparse-iteration-to-scf", "func::FuncOp"> {
483+
let summary = "lower sparse_tensor.iterate/coiterate into scf loops";
484+
let description = [{
485+
This pass lowers `sparse_tensor.iterate` operations into `scf.for/while` operations.
486+
The pass is not yet stablized.
487+
}];
488+
let constructor = "mlir::createLowerSparseIterationToSCFPass()";
489+
let dependentDialects = [
490+
"memref::MemRefDialect",
491+
"scf::SCFDialect",
492+
"sparse_tensor::SparseTensorDialect",
493+
];
494+
}
495+
496+
482497
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
44
SparseAssembler.cpp
55
SparseBufferRewriting.cpp
66
SparseGPUCodegen.cpp
7+
SparseIterationToScf.cpp
78
SparseReinterpretMap.cpp
89
SparseStorageSpecifierToLLVM.cpp
910
SparseSpaceCollapse.cpp
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
2+
#include "Utils/CodegenUtils.h"
3+
#include "Utils/SparseTensorIterator.h"
4+
5+
#include "mlir/Dialect/SCF/IR/SCF.h"
6+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
7+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
8+
#include "mlir/Transforms/OneToNTypeConversion.h"
9+
10+
using namespace mlir;
11+
using namespace mlir::sparse_tensor;
12+
13+
void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
14+
SmallVectorImpl<Type> &fields) {
15+
// Position and coordinate buffer in the sparse structure.
16+
if (enc.getLvlType(lvl).isWithPosLT())
17+
fields.push_back(enc.getPosMemRefType());
18+
if (enc.getLvlType(lvl).isWithCrdLT())
19+
fields.push_back(enc.getCrdMemRefType());
20+
// One index for shape bound (result from lvlOp)
21+
fields.push_back(IndexType::get(enc.getContext()));
22+
}
23+
24+
static std::optional<LogicalResult>
25+
convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
26+
27+
auto idxTp = IndexType::get(itSp.getContext());
28+
for (Level l = itSp.getLoLvl(); l < itSp.getHiLvl(); l++)
29+
convertLevelType(itSp.getEncoding(), l, fields);
30+
31+
// Two indices for lower and upper bound (we only need one pair for the last
32+
// iteration space).
33+
fields.append({idxTp, idxTp});
34+
return success();
35+
}
36+
37+
namespace {
38+
39+
/// Sparse codegen rule for number of entries operator.
40+
class ExtractIterSpaceConverter
41+
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
42+
public:
43+
using OneToNOpConversionPattern::OneToNOpConversionPattern;
44+
LogicalResult
45+
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
46+
OneToNPatternRewriter &rewriter) const override {
47+
Location loc = op.getLoc();
48+
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
49+
50+
// Construct the iteration space.
51+
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
52+
op.getLvlRange(), adaptor.getParentIter());
53+
54+
SmallVector<Value> result = space.toValues();
55+
rewriter.replaceOp(op, result, resultMapping);
56+
return success();
57+
}
58+
};
59+
60+
} // namespace
61+
62+
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
63+
addConversion([](Type type) { return type; });
64+
addConversion(convertIterSpaceType);
65+
66+
addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp,
67+
ValueRange inputs,
68+
Location loc) -> std::optional<Value> {
69+
return builder
70+
.create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs)
71+
.getResult(0);
72+
});
73+
}
74+
75+
void mlir::populateLowerSparseIterationToSCFPatterns(
76+
TypeConverter &converter, RewritePatternSet &patterns) {
77+
patterns.add<ExtractIterSpaceConverter>(converter, patterns.getContext());
78+
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace mlir {
2626
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
2727
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2828
#define GEN_PASS_DEF_SPARSIFICATIONPASS
29+
#define GEN_PASS_DEF_LOWERSPARSEITERATIONTOSCF
2930
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
3031
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
3132
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
@@ -153,10 +154,34 @@ struct LowerForeachToSCFPass
153154
auto *ctx = &getContext();
154155
RewritePatternSet patterns(ctx);
155156
populateLowerForeachToSCFPatterns(patterns);
157+
156158
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
157159
}
158160
};
159161

162+
struct LowerSparseIterationToSCFPass
163+
: public impl::LowerSparseIterationToSCFBase<
164+
LowerSparseIterationToSCFPass> {
165+
LowerSparseIterationToSCFPass() = default;
166+
LowerSparseIterationToSCFPass(const LowerSparseIterationToSCFPass &) =
167+
default;
168+
169+
void runOnOperation() override {
170+
auto *ctx = &getContext();
171+
RewritePatternSet patterns(ctx);
172+
SparseIterationTypeConverter converter;
173+
ConversionTarget target(*ctx);
174+
175+
// The actual conversion.
176+
target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
177+
populateLowerSparseIterationToSCFPatterns(converter, patterns);
178+
179+
if (failed(applyPartialOneToNConversion(getOperation(), converter,
180+
std::move(patterns))))
181+
signalPassFailure();
182+
}
183+
};
184+
160185
struct SparseTensorConversionPass
161186
: public impl::SparseTensorConversionPassBase<SparseTensorConversionPass> {
162187
SparseTensorConversionPass() = default;
@@ -439,6 +464,10 @@ std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
439464
return std::make_unique<LowerForeachToSCFPass>();
440465
}
441466

467+
std::unique_ptr<Pass> mlir::createLowerSparseIterationToSCFPass() {
468+
return std::make_unique<LowerSparseIterationToSCFPass>();
469+
}
470+
442471
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
443472
return std::make_unique<SparseTensorConversionPass>();
444473
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ class LoopEmitter {
222222
///
223223
SmallVector<Value> getValPosits(TensorId tid) const {
224224
SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
225-
Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
225+
Value lastLvlPos = iters[tid].back().back()->getCurPosition().front();
226226
batchCrds.push_back(lastLvlPos);
227227
return batchCrds;
228228
};

0 commit comments

Comments
 (0)