@@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) {
176
176
return linalg::detail::isContractionBody (block, &isPairTemplateImpl<Args...>);
177
177
}
178
178
179
- // / Given a `linalgOp ` and one of its `opOperand `, returns the positions of the
180
- // / iterators of type `iter` that index the `opOperand` as a permutation.
181
- // / This is useful to infer various subcomputations on a given `linalgOp`.
182
- // / This is performed by looking up each result in the matching indexing map and
183
- // / determining whether:
179
+ // / Given an `indexingMap ` and its corresponding `iterators `, returns
180
+ // / the positions of the iterators of type `iter` that are indexed by
181
+ // / the `indexingMap` as a permutation. This is useful to infer various
182
+ // / subcomputations on a `LinalgOp`. This is performed by looking up
183
+ // / each result in the `indexingMap` and determining whether:
184
184
// / - It is a single AffineDimExpr.
185
185
// / - It is the only result involving this AffineDimExpr.
186
186
static llvm::SmallDenseSet<int64_t >
187
- findPermutationsIndexingOperand (LinalgOp linalgOp, OpOperand *opOperand,
187
+ findPermutationsIndexingOperand (AffineMap indexingMap,
188
+ ArrayRef<utils::IteratorType> iterators,
188
189
utils::IteratorType iter) {
190
+ assert (iterators.size () == indexingMap.getNumDims ());
189
191
llvm::SmallDenseSet<int64_t > res;
190
- assert (linalgOp == opOperand->getOwner () && " expected linalgOp owner" );
191
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap (opOperand);
192
192
for (AffineExpr e : indexingMap.getResults ()) {
193
193
if (auto d = dyn_cast<AffineDimExpr>(e)) {
194
- if (linalgOp. getIteratorTypesArray () [d.getPosition ()] == iter &&
194
+ if (iterators [d.getPosition ()] == iter &&
195
195
llvm::count_if (indexingMap.getResults (), [d](AffineExpr e) {
196
196
return e.isFunctionOfDim (d.getPosition ());
197
197
}) == 1 )
@@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel;
206
206
auto red = utils::IteratorType::reduction;
207
207
} // namespace
208
208
209
+ // / Infer the iterator types from the init affine map. This looks at which dims
210
+ // / are present in the map results, and returns an iterator types array with
211
+ // / parallel types for dims that are present, and reduction types for dims that
212
+ // / are not present.
213
+ static FailureOr<SmallVector<utils::IteratorType>>
214
+ inferIteratorsFromOutMap (AffineMap map) {
215
+ if (!map.isProjectedPermutation ())
216
+ return failure ();
217
+ SmallVector<utils::IteratorType> iterators (map.getNumDims (), red);
218
+ for (auto expr : map.getResults ())
219
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
220
+ iterators[dim.getPosition ()] = par;
221
+ return iterators;
222
+ }
223
+
209
224
// / Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
210
225
// / a matmul subcomputation within `linalgOp`. These dimensions are such that:
211
226
// / 1. The m dimension is involved in an outer-product along LHS
@@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction;
217
232
// / 5. Optional batch dimensions that appear in all operands are captured.
218
233
// / This allows e.g. detecting that some contraction is embedded within
219
234
// / `linalgOp` with some orthogonal heuristic.
220
- FailureOr<ContractionDimensions>
221
- mlir::linalg::inferContractionDims (LinalgOp linalgOp) {
222
- if (linalgOp.getNumDpsInits () != 1 || linalgOp.getNumDpsInputs () != 2 )
223
- return failure ();
224
-
225
- llvm::SmallDenseSet<int64_t > a = findPermutationsIndexingOperand (
226
- linalgOp, linalgOp.getDpsInputOperand (0 ), par);
227
- llvm::SmallDenseSet<int64_t > b = findPermutationsIndexingOperand (
228
- linalgOp, linalgOp.getDpsInputOperand (1 ), par);
229
- llvm::SmallDenseSet<int64_t > c = findPermutationsIndexingOperand (
230
- linalgOp, linalgOp.getDpsInitOperand (0 ), par);
235
+ static FailureOr<ContractionDimensions>
236
+ inferContractionDimsImpl (ArrayRef<AffineMap> indexingMaps,
237
+ ArrayRef<utils::IteratorType> iterators) {
238
+ llvm::SmallDenseSet<int64_t > a =
239
+ findPermutationsIndexingOperand (indexingMaps[0 ], iterators, par);
240
+ llvm::SmallDenseSet<int64_t > b =
241
+ findPermutationsIndexingOperand (indexingMaps[1 ], iterators, par);
242
+ llvm::SmallDenseSet<int64_t > c =
243
+ findPermutationsIndexingOperand (indexingMaps[2 ], iterators, par);
231
244
232
245
// A & C - B are the iterators involved in an outer-product along A (the LHS).
233
246
llvm::SmallDenseSet<int64_t > ac = a;
@@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
243
256
llvm::set_intersect (batches, c);
244
257
245
258
// A & B red are the reduction dimensions.
246
- llvm::SmallDenseSet<int64_t > ra = findPermutationsIndexingOperand (
247
- linalgOp, linalgOp. getDpsInputOperand ( 0 ) , red);
248
- llvm::SmallDenseSet<int64_t > rb = findPermutationsIndexingOperand (
249
- linalgOp, linalgOp. getDpsInputOperand ( 1 ) , red);
259
+ llvm::SmallDenseSet<int64_t > ra =
260
+ findPermutationsIndexingOperand (indexingMaps[ 0 ], iterators , red);
261
+ llvm::SmallDenseSet<int64_t > rb =
262
+ findPermutationsIndexingOperand (indexingMaps[ 1 ], iterators , red);
250
263
llvm::set_intersect (ra, rb);
251
264
252
265
// Return each set in sorted order.
@@ -262,6 +275,24 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
262
275
return dimensions;
263
276
}
264
277
278
+ FailureOr<ContractionDimensions>
279
+ mlir::linalg::inferContractionDims (LinalgOp linalgOp) {
280
+ if (linalgOp.getNumDpsInits () != 1 || linalgOp.getNumDpsInputs () != 2 )
281
+ return failure ();
282
+ return inferContractionDimsImpl (linalgOp.getIndexingMapsArray (),
283
+ linalgOp.getIteratorTypesArray ());
284
+ }
285
+
286
+ FailureOr<ContractionDimensions>
287
+ mlir::linalg::inferContractionDims (ArrayRef<AffineMap> indexingMaps) {
288
+ if (indexingMaps.size () != 3 )
289
+ return failure ();
290
+ auto iterators = inferIteratorsFromOutMap (indexingMaps[2 ]);
291
+ if (failed (iterators))
292
+ return failure ();
293
+ return inferContractionDimsImpl (indexingMaps, iterators.value ());
294
+ }
295
+
265
296
namespace mlir ::linalg::detail {
266
297
enum class MatchContractionResult {
267
298
Success = 0 ,
@@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions>
504
535
inferConvolutionDimsImpl (LinalgOp linalgOp,
505
536
ConvAccessExprWalker &inputExprWalker,
506
537
bool allowEmptyConvolvedDims) {
538
+ auto filterMap =
539
+ linalgOp.getMatchingIndexingMap (linalgOp.getDpsInputOperand (1 ));
540
+ auto outputMap =
541
+ linalgOp.getMatchingIndexingMap (linalgOp.getDpsInitOperand (0 ));
507
542
llvm::SmallDenseSet<int64_t > filterDims = findPermutationsIndexingOperand (
508
- linalgOp , linalgOp.getDpsInputOperand ( 1 ), par);
543
+ filterMap , linalgOp.getIteratorTypesArray ( ), par);
509
544
llvm::SmallDenseSet<int64_t > outputDims = findPermutationsIndexingOperand (
510
- linalgOp , linalgOp.getDpsInitOperand ( 0 ), par);
545
+ outputMap , linalgOp.getIteratorTypesArray ( ), par);
511
546
512
547
// unConvolvedDims & outputDims - filterDims are the batch iterators.
513
548
llvm::SmallDenseSet<int64_t > batch = inputExprWalker.unConvolvedDims ;
@@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
529
564
llvm::set_intersect (depth, inputExprWalker.unConvolvedDims );
530
565
531
566
llvm::SmallDenseSet<int64_t > filterReducedDims =
532
- findPermutationsIndexingOperand (linalgOp, linalgOp. getDpsInputOperand ( 1 ) ,
533
- red);
567
+ findPermutationsIndexingOperand (filterMap ,
568
+ linalgOp. getIteratorTypesArray (), red);
534
569
535
570
// convolvedDims & filterReducedDims are the filter loop iterators.
536
571
llvm::SmallDenseSet<int64_t > fl = inputExprWalker.convolvedDims ;
0 commit comments