@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
110
110
}));
111
111
}
112
112
113
- // / Instantiate the tiled implementation of the operation.
113
+ // Instantiate the tiled implementation of the operation.
114
114
FailureOr<TilingResult>
115
115
getTiledImplementation (Operation *op, OpBuilder &b,
116
116
ArrayRef<OpFoldResult> offsets,
@@ -132,66 +132,14 @@ struct LinalgOpTilingInterface
132
132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
133
133
}
134
134
135
- void
136
- getMappedOffsetAndSize (LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
137
- ArrayRef<OpFoldResult> offsets,
138
- ArrayRef<OpFoldResult> sizes,
139
- SmallVectorImpl<OpFoldResult> &mappedOffsets,
140
- SmallVectorImpl<OpFoldResult> &mappedSizes) const {
141
- unsigned numLoops = linalgOp.getNumLoops ();
142
- auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
143
- mappedOffsets.resize (numLoops);
144
- mappedSizes.resize (numLoops);
145
- if (!indexingMap.isPermutation ()) {
146
- SmallVector<Range> iterationDomain =
147
- tilingInterfaceOp.getIterationDomain (b);
148
- for (const auto &&[index, value] : llvm::enumerate (iterationDomain)) {
149
- mappedOffsets[index] = value.offset ;
150
- mappedSizes[index] = value.size ;
151
- }
152
- }
153
- for (const auto &&[index, value] :
154
- llvm::enumerate (indexingMap.getResults ())) {
155
- unsigned dimPosition = cast<AffineDimExpr>(value).getPosition ();
156
- mappedOffsets[dimPosition] = offsets[index];
157
- mappedSizes[dimPosition] = sizes[index];
158
- }
159
- }
160
-
161
- // / Return the details of the output tile generated by the tiled
162
- // / implementation.
163
- LogicalResult getIterationDomainTileFromOperandTile (
164
- Operation *op, OpBuilder &b, unsigned operandNumber,
165
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166
- SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
167
- SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
168
- auto linalgOp = cast<LinalgOp>(op);
169
-
170
- // Check that the indexing map used for the operand is a projected
171
- // permutation. This could be relaxed with a more general approach that can
172
- // map the offsets and sizes from the operand to iteration space tiles
173
- // (filling in full extent for dimensions not used to access the result).
174
- AffineMap indexingMap =
175
- linalgOp.getMatchingIndexingMap (&op->getOpOperand (operandNumber));
176
- if (!indexingMap.isProjectedPermutation ()) {
177
- return emitError (op->getLoc (),
178
- " unhandled get iter domain position when operand is not "
179
- " accessed using a permuted projection" );
180
- }
181
-
182
- getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
183
- iterDomainOffsets, iterDomainSizes);
184
- return success ();
185
- }
186
-
187
- // / Return the details of the output tile generated by the tiled
188
- // / implementation.
135
+ // Return the details of the output tile generated by the tiled
136
+ // implementation.
189
137
LogicalResult
190
138
getResultTilePosition (Operation *op, OpBuilder &b, unsigned resultNumber,
191
139
ArrayRef<OpFoldResult> offsets,
192
140
ArrayRef<OpFoldResult> sizes,
193
- SmallVectorImpl <OpFoldResult> &resultOffsets,
194
- SmallVectorImpl <OpFoldResult> &resultSizes) const {
141
+ SmallVector <OpFoldResult> &resultOffsets,
142
+ SmallVector <OpFoldResult> &resultSizes) const {
195
143
Location loc = op->getLoc ();
196
144
LinalgOp linalgOp = cast<LinalgOp>(op);
197
145
@@ -212,21 +160,6 @@ struct LinalgOpTilingInterface
212
160
return success ();
213
161
}
214
162
215
- FailureOr<TilingResult> getTiledImplementationFromOperandTile (
216
- Operation *op, OpBuilder &b, unsigned operandNumber,
217
- ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218
- SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219
- auto tilingInterfaceOp = cast<TilingInterface>(op);
220
- if (failed (tilingInterfaceOp.getIterationDomainTileFromOperandTile (
221
- b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222
- return emitError (
223
- op->getLoc (),
224
- " unable to obtain the iter domain position of the operation." );
225
- }
226
- return tilingInterfaceOp.getTiledImplementation (b, mappedOffsets,
227
- mappedSizes);
228
- }
229
-
230
163
FailureOr<TilingResult>
231
164
generateResultTileValue (Operation *op, OpBuilder &b, unsigned resultNumber,
232
165
ArrayRef<OpFoldResult> offsets,
@@ -244,16 +177,29 @@ struct LinalgOpTilingInterface
244
177
" unhandled tiled implementation generation when result is not "
245
178
" accessed using a permuted projection" );
246
179
}
247
- SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
248
- getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
249
- mappedOffsets, mappedSizes);
250
- auto tilingInterfaceOp = cast<TilingInterface>(op);
251
- FailureOr<TilingResult> tilingResult =
252
- tilingInterfaceOp.getTiledImplementation (b, mappedOffsets, mappedSizes);
253
180
254
- if (failed (tilingResult))
255
- return failure ();
181
+ auto numLoops = linalgOp.getNumLoops ();
182
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
183
+ SmallVector<OpFoldResult> iterationTileOffsets (numLoops),
184
+ iterationTileSizes (numLoops);
185
+ if (!indexingMap.isPermutation ()) {
186
+ SmallVector<Range> iterationDomain =
187
+ tilingInterfaceOp.getIterationDomain (b);
188
+ for (const auto &range : llvm::enumerate (iterationDomain)) {
189
+ iterationTileOffsets[range.index ()] = range.value ().offset ;
190
+ iterationTileSizes[range.index ()] = range.value ().size ;
191
+ }
192
+ }
193
+ for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
194
+ unsigned dimPosition =
195
+ cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
196
+ iterationTileOffsets[dimPosition] = offsets[resultExpr.index ()];
197
+ iterationTileSizes[dimPosition] = sizes[resultExpr.index ()];
198
+ }
256
199
200
+ FailureOr<TilingResult> tilingResult =
201
+ tilingInterfaceOp.getTiledImplementation (b, iterationTileOffsets,
202
+ iterationTileSizes);
257
203
if (tilingResult->tiledOps .size () != 1 )
258
204
return op->emitOpError (" failed to generate tiled implementation" );
259
205
0 commit comments