Skip to content

Commit 67c2e35

Browse files
authored
[mlir] Add inferContractionDims util for indexing map inputs (llvm#76081)
This PR adds a util function to infer contraction dimensions given only the indexing maps of a linalg operation.
1 parent 9943cd7 commit 67c2e35

File tree

2 files changed

+65
-28
lines changed

2 files changed

+65
-28
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ struct ContractionDimensions {
6262
/// `k`, indices are returned in sorted order.
6363
/// Returns a failure if any of `m`, `n` or `k` is empty.
6464
FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
65+
FailureOr<ContractionDimensions>
66+
inferContractionDims(ArrayRef<AffineMap> indexingMaps);
6567

6668
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
6769
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) {
176176
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
177177
}
178178

179-
/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
180-
/// iterators of type `iter` that index the `opOperand` as a permutation.
181-
/// This is useful to infer various subcomputations on a given `linalgOp`.
182-
/// This is performed by looking up each result in the matching indexing map and
183-
/// determining whether:
179+
/// Given an `indexingMap` and its corresponding `iterators`, returns
180+
/// the positions of the iterators of type `iter` that are indexed by
181+
/// the `indexingMap` as a permutation. This is useful to infer various
182+
/// subcomputations on a `LinalgOp`. This is performed by looking up
183+
/// each result in the `indexingMap` and determining whether:
184184
/// - It is a single AffineDimExpr.
185185
/// - It is the only result involving this AffineDimExpr.
186186
static llvm::SmallDenseSet<int64_t>
187-
findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
187+
findPermutationsIndexingOperand(AffineMap indexingMap,
188+
ArrayRef<utils::IteratorType> iterators,
188189
utils::IteratorType iter) {
190+
assert(iterators.size() == indexingMap.getNumDims());
189191
llvm::SmallDenseSet<int64_t> res;
190-
assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
191-
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
192192
for (AffineExpr e : indexingMap.getResults()) {
193193
if (auto d = dyn_cast<AffineDimExpr>(e)) {
194-
if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
194+
if (iterators[d.getPosition()] == iter &&
195195
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
196196
return e.isFunctionOfDim(d.getPosition());
197197
}) == 1)
@@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel;
206206
auto red = utils::IteratorType::reduction;
207207
} // namespace
208208

209+
/// Infer the iterator types from the init affine map. This looks at which dims
210+
/// are present in the map results, and returns an iterator types array with
211+
/// parallel types for dims that are present, and reduction types for dims that
212+
/// are not present.
213+
static FailureOr<SmallVector<utils::IteratorType>>
214+
inferIteratorsFromOutMap(AffineMap map) {
215+
if (!map.isProjectedPermutation())
216+
return failure();
217+
SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
218+
for (auto expr : map.getResults())
219+
if (auto dim = dyn_cast<AffineDimExpr>(expr))
220+
iterators[dim.getPosition()] = par;
221+
return iterators;
222+
}
223+
209224
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
210225
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
211226
/// 1. The m dimension is involved in an outer-product along LHS
@@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction;
217232
/// 5. Optional batch dimensions that appear in all operands are captured.
218233
/// This allows e.g. detecting that some contraction is embedded within
219234
/// `linalgOp` with some orthogonal heuristic.
220-
FailureOr<ContractionDimensions>
221-
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
222-
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
223-
return failure();
224-
225-
llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
226-
linalgOp, linalgOp.getDpsInputOperand(0), par);
227-
llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
228-
linalgOp, linalgOp.getDpsInputOperand(1), par);
229-
llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
230-
linalgOp, linalgOp.getDpsInitOperand(0), par);
235+
static FailureOr<ContractionDimensions>
236+
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
237+
ArrayRef<utils::IteratorType> iterators) {
238+
llvm::SmallDenseSet<int64_t> a =
239+
findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
240+
llvm::SmallDenseSet<int64_t> b =
241+
findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
242+
llvm::SmallDenseSet<int64_t> c =
243+
findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
231244

232245
// A & C - B are the iterators involved in an outer-product along A (the LHS).
233246
llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
243256
llvm::set_intersect(batches, c);
244257

245258
// A & B red are the reduction dimensions.
246-
llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
247-
linalgOp, linalgOp.getDpsInputOperand(0), red);
248-
llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
249-
linalgOp, linalgOp.getDpsInputOperand(1), red);
259+
llvm::SmallDenseSet<int64_t> ra =
260+
findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
261+
llvm::SmallDenseSet<int64_t> rb =
262+
findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
250263
llvm::set_intersect(ra, rb);
251264

252265
// Return each set in sorted order.
@@ -262,6 +275,24 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
262275
return dimensions;
263276
}
264277

278+
FailureOr<ContractionDimensions>
279+
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
280+
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
281+
return failure();
282+
return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
283+
linalgOp.getIteratorTypesArray());
284+
}
285+
286+
FailureOr<ContractionDimensions>
287+
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
288+
if (indexingMaps.size() != 3)
289+
return failure();
290+
auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
291+
if (failed(iterators))
292+
return failure();
293+
return inferContractionDimsImpl(indexingMaps, iterators.value());
294+
}
295+
265296
namespace mlir::linalg::detail {
266297
enum class MatchContractionResult {
267298
Success = 0,
@@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions>
504535
inferConvolutionDimsImpl(LinalgOp linalgOp,
505536
ConvAccessExprWalker &inputExprWalker,
506537
bool allowEmptyConvolvedDims) {
538+
auto filterMap =
539+
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
540+
auto outputMap =
541+
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
507542
llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
508-
linalgOp, linalgOp.getDpsInputOperand(1), par);
543+
filterMap, linalgOp.getIteratorTypesArray(), par);
509544
llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
510-
linalgOp, linalgOp.getDpsInitOperand(0), par);
545+
outputMap, linalgOp.getIteratorTypesArray(), par);
511546

512547
// unConvolvedDims & outputDims - filterDims are the batch iterators.
513548
llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
@@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
529564
llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
530565

531566
llvm::SmallDenseSet<int64_t> filterReducedDims =
532-
findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
533-
red);
567+
findPermutationsIndexingOperand(filterMap,
568+
linalgOp.getIteratorTypesArray(), red);
534569

535570
// convolvedDims & filterReducedDims are the filter loop iterators.
536571
llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;

0 commit comments

Comments
 (0)