@@ -34,6 +34,19 @@ convertIterSpaceType(IterSpaceType itSp, SmallVectorImpl<Type> &fields) {
34
34
return success ();
35
35
}
36
36
37
+ static std::optional<LogicalResult>
38
+ convertIteratorType (IteratorType itTp, SmallVectorImpl<Type> &fields) {
39
+ // The actually Iterator Values (that are updated every iteration).
40
+ auto idxTp = IndexType::get (itTp.getContext ());
41
+ // TODO: This assumes there is no batch dimenstion in the sparse tensor.
42
+ if (!itTp.isUnique ()) {
43
+ // Segment high for non-unqiue iterator.
44
+ fields.push_back (idxTp);
45
+ }
46
+ fields.push_back (idxTp);
47
+ return success ();
48
+ }
49
+
37
50
namespace {
38
51
39
52
// / Sparse codegen rule for number of entries operator.
@@ -57,10 +70,113 @@ class ExtractIterSpaceConverter
57
70
}
58
71
};
59
72
73
+ class SparseIterateOpConverter : public OneToNOpConversionPattern <IterateOp> {
74
+ public:
75
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
76
+ LogicalResult
77
+ matchAndRewrite (IterateOp op, OpAdaptor adaptor,
78
+ OneToNPatternRewriter &rewriter) const override {
79
+ if (!op.getCrdUsedLvls ().empty ())
80
+ llvm_unreachable (" Not implemented." );
81
+
82
+ Location loc = op.getLoc ();
83
+
84
+ auto iterSpace = SparseIterationSpace::fromValues (
85
+ op.getIterSpace ().getType (), adaptor.getIterSpace (), 0 );
86
+
87
+ std::unique_ptr<SparseIterator> it =
88
+ iterSpace.extractIterator (rewriter, loc);
89
+
90
+ if (it->iteratableByFor ()) {
91
+ auto [lo, hi] = it->genForCond (rewriter, loc);
92
+ Value step = constantIndex (rewriter, loc, 1 );
93
+ SmallVector<Value> ivs;
94
+ for (ValueRange inits : adaptor.getInitArgs ())
95
+ llvm::append_range (ivs, inits);
96
+ scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, ivs);
97
+
98
+ Block *loopBody = op.getBody ();
99
+ OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
100
+ if (failed (typeConverter->convertSignatureArgs (
101
+ loopBody->getArgumentTypes (), bodyTypeMapping)))
102
+ return failure ();
103
+ rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
104
+
105
+ forOp.getBody ()->erase ();
106
+ Region &dstRegion = forOp.getRegion ();
107
+ rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
108
+
109
+ auto yieldOp =
110
+ llvm::cast<sparse_tensor::YieldOp>(forOp.getBody ()->getTerminator ());
111
+
112
+ rewriter.setInsertionPointToEnd (forOp.getBody ());
113
+ // replace sparse_tensor.yield with scf.yield.
114
+ rewriter.create <scf::YieldOp>(loc, yieldOp.getResults ());
115
+ yieldOp.erase ();
116
+
117
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
118
+ rewriter.replaceOp (op, forOp.getResults (), resultMapping);
119
+ } else {
120
+ SmallVector<Value> ivs;
121
+ llvm::append_range (ivs, it->getCursor ());
122
+ for (ValueRange inits : adaptor.getInitArgs ())
123
+ llvm::append_range (ivs, inits);
124
+
125
+ assert (llvm::all_of (ivs, [](Value v) { return v != nullptr ; }));
126
+
127
+ TypeRange types = ValueRange (ivs).getTypes ();
128
+ auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
129
+ SmallVector<Location> l (types.size (), op.getIterator ().getLoc ());
130
+
131
+ // Generates loop conditions.
132
+ Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, types, l);
133
+ rewriter.setInsertionPointToStart (before);
134
+ ValueRange bArgs = before->getArguments ();
135
+ auto [whileCond, remArgs] = it->genWhileCond (rewriter, loc, bArgs);
136
+ assert (remArgs.size () == adaptor.getInitArgs ().size ());
137
+ rewriter.create <scf::ConditionOp>(loc, whileCond, before->getArguments ());
138
+
139
+ // Generates loop body.
140
+ Block *loopBody = op.getBody ();
141
+ OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
142
+ if (failed (typeConverter->convertSignatureArgs (
143
+ loopBody->getArgumentTypes (), bodyTypeMapping)))
144
+ return failure ();
145
+ rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
146
+
147
+ Region &dstRegion = whileOp.getAfter ();
148
+ // TODO: handle uses of coordinate!
149
+ rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
150
+ ValueRange aArgs = whileOp.getAfterArguments ();
151
+ auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
152
+ whileOp.getAfterBody ()->getTerminator ());
153
+
154
+ rewriter.setInsertionPointToEnd (whileOp.getAfterBody ());
155
+
156
+ aArgs = it->linkNewScope (aArgs);
157
+ ValueRange nx = it->forward (rewriter, loc);
158
+ SmallVector<Value> yields;
159
+ llvm::append_range (yields, nx);
160
+ llvm::append_range (yields, yieldOp.getResults ());
161
+
162
+ // replace sparse_tensor.yield with scf.yield.
163
+ yieldOp->erase ();
164
+ rewriter.create <scf::YieldOp>(loc, yields);
165
+
166
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
167
+ rewriter.replaceOp (
168
+ op, whileOp.getResults ().drop_front (it->getCursor ().size ()),
169
+ resultMapping);
170
+ }
171
+ return success ();
172
+ }
173
+ };
174
+
60
175
} // namespace
61
176
62
177
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter () {
63
178
addConversion ([](Type type) { return type; });
179
+ addConversion (convertIteratorType);
64
180
addConversion (convertIterSpaceType);
65
181
66
182
addSourceMaterialization ([](OpBuilder &builder, IterSpaceType spTp,
@@ -74,5 +190,6 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
74
190
75
191
void mlir::populateLowerSparseIterationToSCFPatterns (
76
192
TypeConverter &converter, RewritePatternSet &patterns) {
77
- patterns.add <ExtractIterSpaceConverter>(converter, patterns.getContext ());
193
+ patterns.add <ExtractIterSpaceConverter, SparseIterateOpConverter>(
194
+ converter, patterns.getContext ());
78
195
}
0 commit comments