-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add inferContractionDims util for indexing map inputs #76081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (Max191) ChangesThis PR adds a util function to infer contraction dimensions given only the indexing maps of a linalg operation. Full diff: https://github.com/llvm/llvm-project/pull/76081.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c8240267e7d05..f92843a1dcb987 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -62,6 +62,8 @@ struct ContractionDimensions {
/// `k`, indices are returned in sorted order.
/// Returns a failure if any of `m`, `n` or `k` is empty.
FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+FailureOr<ContractionDimensions>
+inferContractionDims(ArrayRef<AffineMap> indexingMaps);
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d32f22a3e..78a13017ae5c3e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -176,22 +176,14 @@ static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
-/// - It is a single AffineDimExpr.
-/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
- utils::IteratorType iter) {
+findPermutationsIndexingOperandImpl(AffineMap indexingMap,
+ ArrayRef<utils::IteratorType> iterators,
+ utils::IteratorType iter) {
llvm::SmallDenseSet<int64_t> res;
- assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = dyn_cast<AffineDimExpr>(e)) {
- if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+ if (iterators[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
@@ -201,11 +193,48 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
return res;
}
+/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
+/// iterators of type `iter` that index the `opOperand` as a permutation.
+/// This is useful to infer various subcomputations on a given `linalgOp`.
+/// This is performed by looking up each result in the matching indexing map and
+/// determining whether:
+/// - It is a single AffineDimExpr.
+/// - It is the only result involving this AffineDimExpr.
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+ utils::IteratorType iter) {
+ assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
+ return findPermutationsIndexingOperandImpl(
+ linalgOp.getMatchingIndexingMap(opOperand),
+ linalgOp.getIteratorTypesArray(), iter);
+}
+
+static llvm::SmallDenseSet<int64_t>
+findPermutationsIndexingOperand(AffineMap indexingMap,
+ ArrayRef<utils::IteratorType> iterators,
+ utils::IteratorType iter) {
+ return findPermutationsIndexingOperandImpl(indexingMap, iterators, iter);
+}
+
namespace {
auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
+/// Infer the iterator types from the init affine map. This looks at which dims
+/// are present in the map results, and returns an iterator types array with
+/// parallel types for dims that are present, and reduction types for dims that
+/// are not present.
+static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) {
+ SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
+ for (auto expr : map.getResults()) {
+ if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
+ iterators[dim.getPosition()] = par;
+ }
+ }
+ return iterators;
+}
+
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
@@ -218,16 +247,14 @@ auto red = utils::IteratorType::reduction;
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
FailureOr<ContractionDimensions>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
- return failure();
-
- llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), par);
- llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
- llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
+inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<utils::IteratorType> iterators) {
+ llvm::SmallDenseSet<int64_t> a =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
+ llvm::SmallDenseSet<int64_t> b =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
+ llvm::SmallDenseSet<int64_t> c =
+ findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +270,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
llvm::set_intersect(batches, c);
// A & B red are the reduction dimensions.
- llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), red);
- llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), red);
+ llvm::SmallDenseSet<int64_t> ra =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
+ llvm::SmallDenseSet<int64_t> rb =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);
// Return each set in sorted order.
@@ -262,6 +289,22 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
return dimensions;
}
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+ return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
+ linalgOp.getIteratorTypesArray());
+}
+
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return failure();
+ return inferContractionDimsImpl(indexingMaps,
+ inferIteratorsFromOutMap(indexingMaps[2]));
+}
+
namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
|
@@ -218,16 +247,14 @@ auto red = utils::IteratorType::reduction; | |||
/// This allows e.g. detecting that some contraction is embedded within | |||
/// `linalgOp` with some orthogonal heuristic. | |||
FailureOr<ContractionDimensions> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Make static
} | ||
|
||
static llvm::SmallDenseSet<int64_t> | ||
findPermutationsIndexingOperand(AffineMap indexingMap, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this static function? I think you can just call the impl directly where appropriate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I personally prefer not to call the impl function and just use overloads, but I can change it if you prefer me to directly call the impl given that it is just a static helper.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just findPermutationsIndexingOperandImpl
-> findPermutationsIndexingOperand
then and drop this?
/// are present in the map results, and returns an iterator types array with | ||
/// parallel types for dims that are present, and reduction types for dims that | ||
/// are not present. | ||
static ArrayRef<utils::IteratorType> inferIteratorsFromOutMap(AffineMap map) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Given that inferContractionDims is an API function, it might be nice to make this FailureOr
and fail if not a projected permutation.
a4614e8
to
1a6b0b5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok this looks fine to me, but given there is no test here and it's just changing API surface, I'm not an authority on what to/how to make API changes like this so give others a chance to review before landing.
} | ||
|
||
static llvm::SmallDenseSet<int64_t> | ||
findPermutationsIndexingOperand(AffineMap indexingMap, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just findPermutationsIndexingOperandImpl
-> findPermutationsIndexingOperand
then and drop this?
/// are present in the map results, and returns an iterator types array with | ||
/// parallel types for dims that are present, and reduction types for dims that | ||
/// are not present. | ||
static FailureOr<ArrayRef<utils::IteratorType>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I would just return the SmallVector directly
} | ||
} | ||
if (iterators.size() != map.getNumDims()) | ||
return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This check does not seem possible given you aren't changing the size of the vector.
/// are present in the map results, and returns an iterator types array with | ||
/// parallel types for dims that are present, and reduction types for dims that | ||
/// are not present. | ||
static FailureOr<ArrayRef<utils::IteratorType>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this method correct? ArrayRef does not own the underlying data and you return the small vector allocated in the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems to be working for me, but probably better to return SmallVector. I'll update
@@ -15,9 +15,11 @@ | |||
#include "mlir/Dialect/Linalg/IR/Linalg.h" | |||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | |||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | |||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these includes needed?
@@ -201,11 +203,37 @@ findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand, | |||
return res; | |||
} | |||
|
|||
static llvm::SmallDenseSet<int64_t> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would drop this method. It is used only for inferConvolutionDimsImpl
.
if (!map.isProjectedPermutation()) | ||
return failure(); | ||
SmallVector<utils::IteratorType> iterators(map.getNumDims(), red); | ||
for (auto expr : map.getResults()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: drop braces for simple single stmt body if and for.
7b03181
to
a92bd97
Compare
a92bd97
to
8dd51b8
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you
This PR adds a util function to infer contraction dimensions given only the indexing maps of a linalg operation.