Skip to content

[mlir] Add inferContractionDims util for indexing map inputs #76081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 2, 2024

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Dec 20, 2023

This PR adds a util function to infer contraction dimensions given only the indexing maps of a linalg operation.

@llvmbot
Copy link
Member

llvmbot commented Dec 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (Max191)

Changes

This PR adds a util function to infer contraction dimensions given only the indexing maps of a linalg operation.


Full diff: https://github.com/llvm/llvm-project/pull/76081.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+2)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+69-26)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..f92843a1dcb987 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -62,6 +62,8 @@ struct ContractionDimensions {
 /// `k`, indices are returned in sorted order.
 /// Returns a failure if any of `m`, `n` or `k` is empty.
 FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+FailureOr<ContractionDimensions>
+inferContractionDims(ArrayRef<AffineMap> indexingMaps);
 
 /// Checks whether `linalgOp` conforms to ContractionOpInterface.
 // TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..78a13017ae5c3e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -176,22 +176,14 @@ static bool isContractionBody(Block &block) {
   return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
 }
 
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
-///   - It is a single AffineDimExpr.
-///   - It is the only result involving this AffineDimExpr.
 static llvm::SmallDenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
-                                utils::IteratorType iter) {
+findPermutationsIndexingOperandImpl(AffineMap indexingMap,
+                                    ArrayRef<utils::IteratorType> iterators,
+                                    utils::IteratorType iter) {
   llvm::SmallDenseSet<int64_t> res;
-  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
-  AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
   for (AffineExpr e : indexingMap.getResults()) {
     if (auto d = dyn_cast<AffineDimExpr>(e)) {
-      if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+      if (iterators[d.getPosition()] == iter &&
           llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
             return e.isFunctionOfDim(d.getPosition());
           }) == 1)
@@ -201,11 +193,48 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
   return res;
 }
 
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+///   - It is a single AffineDimExpr.
+///   - It is the only result involving this AffineDimExpr.
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+                                utils::IteratorType iter) {
+  assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+  return findPermutationsIndexingOperandImpl(
+      linalgOp.getMatchingIndexingMap(opOperand),
+      linalgOp.getIteratorTypesArray(), iter);
+}
+
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(AffineMap indexingMap,
+                                ArrayRef<utils::IteratorType> iterators,
+                                utils::IteratorType iter) {
+  return findPermutationsIndexingOperandImpl(indexingMap, iterators, iter);
+}
+
 namespace {
 auto par = utils::IteratorType::parallel;
 auto red = utils::IteratorType::reduction;
 } // namespace
 
+/// Infer the iterator types from the init affine map. This looks at which dims
+/// are present in the map results, and returns an iterator types array with
+/// parallel types for dims that are present, and reduction types for dims that
+/// are not present.
+static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) {
+  SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
+  for (auto expr : map.getResults()) {
+    if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+      iterators[dim.getPosition()] = par;
+    }
+  }
+  return iterators;
+}
+
 /// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
 /// a matmul subcomputation within `linalgOp`. These dimensions are such that:
 ///   1. The m dimension is involved in an outer-product along LHS
@@ -218,16 +247,14 @@ auto red = utils::IteratorType::reduction;
 /// This allows e.g. detecting that some contraction is embedded within
 /// `linalgOp` with some orthogonal heuristic.
 FailureOr<ContractionDimensions>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
-  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
-    return failure();
-
-  llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), par);
-  llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), par);
-  llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInitOperand(0), par);
+inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
+                         ArrayRef<utils::IteratorType> iterators) {
+  llvm::SmallDenseSet<int64_t> a =
+      findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
+  llvm::SmallDenseSet<int64_t> b =
+      findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
+  llvm::SmallDenseSet<int64_t> c =
+      findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
 
   // A & C - B are the iterators involved in an outer-product along A (the LHS).
   llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +270,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
   llvm::set_intersect(batches, c);
 
   // A & B red are the reduction dimensions.
-  llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(0), red);
-  llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
-      linalgOp, linalgOp.getDpsInputOperand(1), red);
+  llvm::SmallDenseSet<int64_t> ra =
+      findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
+  llvm::SmallDenseSet<int64_t> rb =
+      findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
   llvm::set_intersect(ra, rb);
 
   // Return each set in sorted order.
@@ -262,6 +289,22 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
   return dimensions;
 }
 
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+  if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+    return failure();
+  return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
+                                  linalgOp.getIteratorTypesArray());
+}
+
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
+  if (indexingMaps.size() != 3)
+    return failure();
+  return inferContractionDimsImpl(indexingMaps,
+                                  inferIteratorsFromOutMap(indexingMaps[2]));
+}
+
 namespace mlir::linalg::detail {
 enum class MatchContractionResult {
   Success = 0,

@@ -218,16 +247,14 @@ auto red = utils::IteratorType::reduction;
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
FailureOr<ContractionDimensions>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Make static

}

static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this static function? I think you can just call the impl directly where appropriate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally prefer not to call the impl function and just use overloads, but I can change it if you prefer me to directly call the impl given that it is just a static helper.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just findPermutationsIndexingOperandImpl -> findPermutationsIndexingOperand then and drop this?

/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Given that inferContractionDims is an API function, it might be nice to make this FailureOr and fail if not a projected permutation.

@Max191 Max191 force-pushed the contraction-op-interface-affinemaps branch from a4614e8 to 1a6b0b5 Compare December 21, 2023 15:18
@Max191 Max191 requested a review from qedawkins December 21, 2023 15:19
Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok this looks fine to me, but given there is no test here and it's just changing API surface, I'm not an authority on what to/how to make API changes like this so give others a chance to review before landing.

}

static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(AffineMap indexingMap,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe just findPermutationsIndexingOperandImpl -> findPermutationsIndexingOperand then and drop this?

/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
static FailureOr<ArrayRef<utils::IteratorType>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I would just return the SmallVector directly

}
}
if (iterators.size() != map.getNumDims())
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This check does not seem possible given you aren't changing the size of the vector.

/// are present in the map results, and returns an iterator types array with
/// parallel types for dims that are present, and reduction types for dims that
/// are not present.
static FailureOr<ArrayRef<utils::IteratorType>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this method correct? ArrayRef does not own the underlying data and you return the small vector allocated in the function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to be working for me, but probably better to return SmallVector. I'll update

@@ -15,9 +15,11 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these includes needed?

@@ -201,11 +203,37 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
return res;
}

static llvm::SmallDenseSet<int64_t>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would drop this method. It is used only for inferConvolutionDimsImpl.

if (!map.isProjectedPermutation())
return failure();
SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
for (auto expr : map.getResults()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: drop braces for simple single stmt body if and for.

@Max191 Max191 force-pushed the contraction-op-interface-affinemaps branch from 7b03181 to a92bd97 Compare December 21, 2023 20:19
@Max191 Max191 force-pushed the contraction-op-interface-affinemaps branch from a92bd97 to 8dd51b8 Compare January 2, 2024 21:46
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you

@Max191 Max191 merged commit 67c2e35 into llvm:main Jan 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants