Skip to content

[mlir] Add inferContractionDims util for indexing map inputs #76081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ struct ContractionDimensions {
/// `k`, indices are returned in sorted order.
/// Returns a failure if any of `m`, `n` or `k` is empty.
FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
FailureOr<ContractionDimensions>
inferContractionDims(ArrayRef<AffineMap> indexingMaps);

/// Checks whether `linalgOp` conforms to ContractionOpInterface.
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
Expand Down
91 changes: 63 additions & 28 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}

/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
/// iterators of type `iter` that index the `opOperand` as a permutation.
/// This is useful to infer various subcomputations on a given `linalgOp`.
/// This is performed by looking up each result in the matching indexing map and
/// determining whether:
/// Given an `indexingMap` and its corresponding `iterators`, returns
/// the positions of the iterators of type `iter` that are indexed by
/// the `indexingMap` as a permutation. This is useful to infer various
/// subcomputations on a `LinalgOp`. This is performed by looking up
/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
findPermutationsIndexingOperand(AffineMap indexingMap,
ArrayRef<utils::IteratorType> iterators,
utils::IteratorType iter) {
assert(iterators.size() == indexingMap.getNumDims());
llvm::SmallDenseSet<int64_t> res;
assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = dyn_cast<AffineDimExpr>(e)) {
if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
if (iterators[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
Expand All @@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace

/// Derives an iterator types array from the indexing map of the init operand.
/// Each loop dimension that appears among the results of `map` is classified
/// as parallel; each dimension that does not appear is classified as
/// reduction. Returns failure when `map` is not a projected permutation.
static FailureOr<SmallVector<utils::IteratorType>>
inferIteratorsFromOutMap(AffineMap map) {
  if (!map.isProjectedPermutation())
    return failure();
  // Start with every dim marked as reduction, then promote the dims that are
  // referenced by a result of the map to parallel.
  SmallVector<utils::IteratorType> iteratorTypes(map.getNumDims(), red);
  for (AffineExpr result : map.getResults()) {
    auto dimExpr = dyn_cast<AffineDimExpr>(result);
    if (!dimExpr)
      continue;
    iteratorTypes[dimExpr.getPosition()] = par;
  }
  return iteratorTypes;
}

/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
Expand All @@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction;
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
return failure();

llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(0), par);
llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), par);
llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInitOperand(0), par);
static FailureOr<ContractionDimensions>
inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
ArrayRef<utils::IteratorType> iterators) {
llvm::SmallDenseSet<int64_t> a =
findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
llvm::SmallDenseSet<int64_t> b =
findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
llvm::SmallDenseSet<int64_t> c =
findPermutationsIndexingOperand(indexingMaps[2], iterators, par);

// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
Expand All @@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
llvm::set_intersect(batches, c);

// A & B red are the reduction dimensions.
llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(0), red);
llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), red);
llvm::SmallDenseSet<int64_t> ra =
findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
llvm::SmallDenseSet<int64_t> rb =
findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);

// Return each set in sorted order.
Expand All @@ -262,6 +275,24 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
return dimensions;
}

/// Infers the contraction dimensions of `linalgOp` from its indexing maps and
/// iterator types. Only ops with exactly two DPS inputs and one DPS init are
/// considered; any other operand configuration yields failure.
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
  const bool hasContractionOperands =
      linalgOp.getNumDpsInits() == 1 && linalgOp.getNumDpsInputs() == 2;
  if (!hasContractionOperands)
    return failure();

  auto maps = linalgOp.getIndexingMapsArray();
  auto iteratorTypes = linalgOp.getIteratorTypesArray();
  return inferContractionDimsImpl(maps, iteratorTypes);
}

/// Infers the contraction dimensions from a list of indexing maps alone.
/// `indexingMaps` must hold exactly three maps, ordered as (lhs, rhs, out).
/// The iterator types are derived from the out map: dims present in its
/// results are treated as parallel, all other dims as reduction.
/// Returns failure if:
///   - the number of maps is not three,
///   - the maps do not all have the same number of dims,
///   - the iterator types cannot be inferred from the out map, or
///   - the underlying contraction inference fails.
FailureOr<ContractionDimensions>
mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
  if (indexingMaps.size() != 3)
    return failure();
  // The iterator types inferred from the out map are indexed by the dim
  // positions of every map, so all maps must share one iteration space.
  // Bail out gracefully instead of relying on a downstream assertion.
  const unsigned numDims = indexingMaps[2].getNumDims();
  if (indexingMaps[0].getNumDims() != numDims ||
      indexingMaps[1].getNumDims() != numDims)
    return failure();
  auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
  if (failed(iterators))
    return failure();
  return inferContractionDimsImpl(indexingMaps, *iterators);
}

namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
Expand Down Expand Up @@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
ConvAccessExprWalker &inputExprWalker,
bool allowEmptyConvolvedDims) {
auto filterMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
auto outputMap =
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInputOperand(1), par);
filterMap, linalgOp.getIteratorTypesArray(), par);
llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
linalgOp, linalgOp.getDpsInitOperand(0), par);
outputMap, linalgOp.getIteratorTypesArray(), par);

// unConvolvedDims & outputDims - filterDims are the batch iterators.
llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
Expand All @@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);

llvm::SmallDenseSet<int64_t> filterReducedDims =
findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
red);
findPermutationsIndexingOperand(filterMap,
linalgOp.getIteratorTypesArray(), red);

// convolvedDims & filterReducedDims are the filter loop iterators.
llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
Expand Down