@@ -132,6 +132,58 @@ struct LinalgOpTilingInterface
132
132
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
133
133
}
134
134
135
+ void getMappedOffsetAndSize (LinalgOp linalgOp, OpBuilder &b,
136
+ AffineMap indexingMap,
137
+ ArrayRef<OpFoldResult> offsets,
138
+ ArrayRef<OpFoldResult> sizes,
139
+ SmallVector<OpFoldResult> &mappedOffsets,
140
+ SmallVector<OpFoldResult> &mappedSizes) const {
141
+ auto numLoops = linalgOp.getNumLoops ();
142
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation ());
143
+ mappedOffsets.resize (numLoops);
144
+ mappedSizes.resize (numLoops);
145
+ if (!indexingMap.isPermutation ()) {
146
+ SmallVector<Range> iterationDomain =
147
+ tilingInterfaceOp.getIterationDomain (b);
148
+ for (const auto &range : llvm::enumerate (iterationDomain)) {
149
+ mappedOffsets[range.index ()] = range.value ().offset ;
150
+ mappedSizes[range.index ()] = range.value ().size ;
151
+ }
152
+ }
153
+ for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
154
+ unsigned dimPosition =
155
+ cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
156
+ mappedOffsets[dimPosition] = offsets[resultExpr.index ()];
157
+ mappedSizes[dimPosition] = sizes[resultExpr.index ()];
158
+ }
159
+ }
160
+
161
+ // Return the details of the output tile generated by the tiled
162
+ // implementation.
163
+ LogicalResult getIterationDomainTileFromOperandTile (
164
+ Operation *op, OpBuilder &b, unsigned operandNumber,
165
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
166
+ SmallVector<OpFoldResult> &iterDomainOffsets,
167
+ SmallVector<OpFoldResult> &iterDomainSizes) const {
168
+ auto linalgOp = cast<LinalgOp>(op);
169
+
170
+ // Check that the indexing map used for the operand is a projected
171
+ // permutation. This could be relaxed with a more general approach that can
172
+ // map the offsets and sizes from the operand to iteration space tiles
173
+ // (filling in full extent for dimensions not used to access the result).
174
+ AffineMap indexingMap =
175
+ linalgOp.getMatchingIndexingMap (&op->getOpOperand (operandNumber));
176
+ if (!indexingMap.isProjectedPermutation ()) {
177
+ return op->emitOpError (
178
+ " unhandled get iter domain position when operand is not "
179
+ " accessed using a permuted projection" );
180
+ }
181
+
182
+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
183
+ iterDomainOffsets, iterDomainSizes);
184
+ return success ();
185
+ }
186
+
135
187
// Return the details of the output tile generated by the tiled
136
188
// implementation.
137
189
LogicalResult
@@ -160,6 +212,20 @@ struct LinalgOpTilingInterface
160
212
return success ();
161
213
}
162
214
215
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile (
216
+ Operation *op, OpBuilder &b, unsigned operandNumber,
217
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
218
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
220
+ if (failed (tilingInterfaceOp.getIterationDomainTileFromOperandTile (
221
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
222
+ return op->emitOpError (
223
+ " unable to obtain the iter domain position of the operation." );
224
+ }
225
+ return tilingInterfaceOp.getTiledImplementation (b, mappedOffsets,
226
+ mappedSizes);
227
+ }
228
+
163
229
FailureOr<TilingResult>
164
230
generateResultTileValue (Operation *op, OpBuilder &b, unsigned resultNumber,
165
231
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +243,16 @@ struct LinalgOpTilingInterface
177
243
" unhandled tiled implementation generation when result is not "
178
244
" accessed using a permuted projection" );
179
245
}
180
-
181
- auto numLoops = linalgOp.getNumLoops ();
246
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
247
+ getMappedOffsetAndSize (linalgOp, b, indexingMap, offsets, sizes,
248
+ mappedOffsets, mappedSizes);
182
249
auto tilingInterfaceOp = cast<TilingInterface>(op);
183
- SmallVector<OpFoldResult> iterationTileOffsets (numLoops),
184
- iterationTileSizes (numLoops);
185
- if (!indexingMap.isPermutation ()) {
186
- SmallVector<Range> iterationDomain =
187
- tilingInterfaceOp.getIterationDomain (b);
188
- for (const auto &range : llvm::enumerate (iterationDomain)) {
189
- iterationTileOffsets[range.index ()] = range.value ().offset ;
190
- iterationTileSizes[range.index ()] = range.value ().size ;
191
- }
192
- }
193
- for (const auto &resultExpr : llvm::enumerate (indexingMap.getResults ())) {
194
- unsigned dimPosition =
195
- cast<AffineDimExpr>(resultExpr.value ()).getPosition ();
196
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index ()];
197
- iterationTileSizes[dimPosition] = sizes[resultExpr.index ()];
198
- }
199
-
200
250
FailureOr<TilingResult> tilingResult =
201
- tilingInterfaceOp.getTiledImplementation (b, iterationTileOffsets,
202
- iterationTileSizes);
251
+ tilingInterfaceOp.getTiledImplementation (b, mappedOffsets, mappedSizes);
252
+
253
+ if (failed (tilingResult))
254
+ return failure ();
255
+
203
256
if (tilingResult->tiledOps .size () != 1 )
204
257
return op->emitOpError (" failed to generate tiled implementation" );
205
258
0 commit comments