@@ -34,20 +34,6 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
34
34
return success ();
35
35
}
36
36
37
- static std::optional<LogicalResult>
38
- convertIteratorType (IteratorType itTp, SmallVectorImpl<Type> &fields) {
39
- // The actually Iterator Values (that are updated every iteration).
40
- auto idxTp = IndexType::get (itTp.getContext ());
41
- // TODO: handle batch dimension.
42
- assert (itTp.getEncoding ().getBatchLvlRank () == 0 );
43
- if (!itTp.isUnique ()) {
44
- // Segment high for non-unique iterator.
45
- fields.push_back (idxTp);
46
- }
47
- fields.push_back (idxTp);
48
- return success ();
49
- }
50
-
51
37
namespace {
52
38
53
39
// / Sparse codegen rule for number of entries operator.
@@ -71,114 +57,10 @@ class ExtractIterSpaceConverter
71
57
}
72
58
};
73
59
74
- class SparseIterateOpConverter : public OneToNOpConversionPattern <IterateOp> {
75
- public:
76
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
77
- LogicalResult
78
- matchAndRewrite (IterateOp op, OpAdaptor adaptor,
79
- OneToNPatternRewriter &rewriter) const override {
80
- if (!op.getCrdUsedLvls ().empty ())
81
- return rewriter.notifyMatchFailure (
82
- op, " non-empty coordinates list not implemented." );
83
-
84
- Location loc = op.getLoc ();
85
-
86
- auto iterSpace = SparseIterationSpace::fromValues (
87
- op.getIterSpace ().getType (), adaptor.getIterSpace (), 0 );
88
-
89
- std::unique_ptr<SparseIterator> it =
90
- iterSpace.extractIterator (rewriter, loc);
91
-
92
- if (it->iteratableByFor ()) {
93
- auto [lo, hi] = it->genForCond (rewriter, loc);
94
- Value step = constantIndex (rewriter, loc, 1 );
95
- SmallVector<Value> ivs;
96
- for (ValueRange inits : adaptor.getInitArgs ())
97
- llvm::append_range (ivs, inits);
98
- scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, ivs);
99
-
100
- Block *loopBody = op.getBody ();
101
- OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
102
- if (failed (typeConverter->convertSignatureArgs (
103
- loopBody->getArgumentTypes (), bodyTypeMapping)))
104
- return failure ();
105
- rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
106
-
107
- forOp.getBody ()->erase ();
108
- Region &dstRegion = forOp.getRegion ();
109
- rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
110
-
111
- auto yieldOp =
112
- llvm::cast<sparse_tensor::YieldOp>(forOp.getBody ()->getTerminator ());
113
-
114
- rewriter.setInsertionPointToEnd (forOp.getBody ());
115
- // replace sparse_tensor.yield with scf.yield.
116
- rewriter.create <scf::YieldOp>(loc, yieldOp.getResults ());
117
- yieldOp.erase ();
118
-
119
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
120
- rewriter.replaceOp (op, forOp.getResults (), resultMapping);
121
- } else {
122
- SmallVector<Value> ivs;
123
- llvm::append_range (ivs, it->getCursor ());
124
- for (ValueRange inits : adaptor.getInitArgs ())
125
- llvm::append_range (ivs, inits);
126
-
127
- assert (llvm::all_of (ivs, [](Value v) { return v != nullptr ; }));
128
-
129
- TypeRange types = ValueRange (ivs).getTypes ();
130
- auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
131
- SmallVector<Location> l (types.size (), op.getIterator ().getLoc ());
132
-
133
- // Generates loop conditions.
134
- Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, types, l);
135
- rewriter.setInsertionPointToStart (before);
136
- ValueRange bArgs = before->getArguments ();
137
- auto [whileCond, remArgs] = it->genWhileCond (rewriter, loc, bArgs);
138
- assert (remArgs.size () == adaptor.getInitArgs ().size ());
139
- rewriter.create <scf::ConditionOp>(loc, whileCond, before->getArguments ());
140
-
141
- // Generates loop body.
142
- Block *loopBody = op.getBody ();
143
- OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
144
- if (failed (typeConverter->convertSignatureArgs (
145
- loopBody->getArgumentTypes (), bodyTypeMapping)))
146
- return failure ();
147
- rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
148
-
149
- Region &dstRegion = whileOp.getAfter ();
150
- // TODO: handle uses of coordinate!
151
- rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
152
- ValueRange aArgs = whileOp.getAfterArguments ();
153
- auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
154
- whileOp.getAfterBody ()->getTerminator ());
155
-
156
- rewriter.setInsertionPointToEnd (whileOp.getAfterBody ());
157
-
158
- aArgs = it->linkNewScope (aArgs);
159
- ValueRange nx = it->forward (rewriter, loc);
160
- SmallVector<Value> yields;
161
- llvm::append_range (yields, nx);
162
- llvm::append_range (yields, yieldOp.getResults ());
163
-
164
- // replace sparse_tensor.yield with scf.yield.
165
- yieldOp->erase ();
166
- rewriter.create <scf::YieldOp>(loc, yields);
167
-
168
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
169
- rewriter.replaceOp (
170
- op, whileOp.getResults ().drop_front (it->getCursor ().size ()),
171
- resultMapping);
172
- }
173
- return success ();
174
- }
175
- };
176
-
177
60
} // namespace
178
61
179
62
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter () {
180
63
addConversion ([](Type type) { return type; });
181
- addConversion (convertIteratorType);
182
64
addConversion (convertIterSpaceType);
183
65
184
66
addSourceMaterialization ([](OpBuilder &builder, IterSpaceType spTp,
@@ -192,6 +74,5 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
192
74
193
75
void mlir::populateLowerSparseIterationToSCFPatterns (
194
76
TypeConverter &converter, RewritePatternSet &patterns) {
195
- patterns.add <ExtractIterSpaceConverter, SparseIterateOpConverter>(
196
- converter, patterns.getContext ());
77
+ patterns.add <ExtractIterSpaceConverter>(converter, patterns.getContext ());
197
78
}
0 commit comments