
[MLIR][Linalg] Add pass to convert linalg.generic back to named ops #95656

Merged · 6 commits · Jun 30, 2024
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -104,6 +104,11 @@ def LinalgGeneralizeNamedOpsPass : Pass<"linalg-generalize-named-ops"> {
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgSpecializeGenericOpsPass : Pass<"linalg-specialize-generic-ops"> {
let summary = "Convert generic ops back to named ops";
let dependentDialects = ["linalg::LinalgDialect"];
}

def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
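As a rough illustration of what the new pass is meant to do (this example is not part of the patch; the function name, shapes, and map aliases are made up): a linalg.generic whose body is a multiply-accumulate and whose indexing maps are the canonical matmul maps should be rewritten to linalg.matmul by `mlir-opt --linalg-specialize-generic-ops`.

#matmul_a = affine_map<(m, n, k) -> (m, k)>
#matmul_b = affine_map<(m, n, k) -> (k, n)>
#matmul_c = affine_map<(m, n, k) -> (m, n)>

func.func @specialize_to_matmul(%A: tensor<16x8xf32>, %B: tensor<8x32xf32>,
                                %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  // A mul + add body with the identity matmul maps: recognized as a contraction.
  %0 = linalg.generic {
         indexing_maps = [#matmul_a, #matmul_b, #matmul_c],
         iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
         outs(%C : tensor<16x32xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %mul = arith.mulf %a, %b : f32
      %add = arith.addf %c, %mul : f32
      linalg.yield %add : f32
  } -> tensor<16x32xf32>
  // Expected result after the pass:
  //   %0 = linalg.matmul ins(%A, %B : tensor<16x8xf32>, tensor<8x32xf32>)
  //                      outs(%C : tensor<16x32xf32>) -> tensor<16x32xf32>
  return %0 : tensor<16x32xf32>
}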
23 changes: 23 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1395,6 +1395,20 @@ struct LinalgGeneralizationPattern
}
};

struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;

FailureOr<GenericOp>
returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const {
return specializeGenericOp(rewriter, op);
}

LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
return returningMatchAndRewrite(op, rewriter);
}
};

/// Vectorization pattern for memref::CopyOp.
struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
@@ -1546,6 +1560,15 @@ void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns);
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);

/// Populates `patterns` with patterns that convert linalg.generic ops to named
/// ops where possible. A linalg.generic can represent a wide range of complex
/// computations for which no equivalent linalg named op exists, e.g. a
/// linalg.generic that takes a tensor and computes a polynomial such as:
///     p(x) = an*x^n + ... + a1*x + a0
/// has no named op to convert to. Many such cases exist.
void populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns);

/// Linalg decompose convolutions patterns

/// Populates patterns to decompose high-D convolution ops into low-D ones.
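The doc comment for populateLinalgGenericOpsSpecializationPatterns above mentions a polynomial as an example with no named-op equivalent. A hedged sketch of such a case (the function name and coefficients are illustrative); the specialization patterns are expected to leave this op as a linalg.generic:

#id1d = affine_map<(d0) -> (d0)>

func.func @polynomial(%X: tensor<64xf32>, %Out: tensor<64xf32>) -> tensor<64xf32> {
  %a0 = arith.constant 1.0 : f32
  %a1 = arith.constant 2.0 : f32
  %a2 = arith.constant 3.0 : f32
  // Elementwise p(x) = a2*x^2 + a1*x + a0: no linalg named op expresses this.
  %0 = linalg.generic {
         indexing_maps = [#id1d, #id1d],
         iterator_types = ["parallel"]}
         ins(%X : tensor<64xf32>) outs(%Out : tensor<64xf32>) {
    ^bb0(%x: f32, %out: f32):
      %x2 = arith.mulf %x, %x : f32
      %t2 = arith.mulf %a2, %x2 : f32
      %t1 = arith.mulf %a1, %x : f32
      %s  = arith.addf %t2, %t1 : f32
      %p  = arith.addf %s, %a0 : f32
      linalg.yield %p : f32
  } -> tensor<64xf32>
  return %0 : tensor<64xf32>
}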
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -105,9 +105,9 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel, and have only tensor semantics.
// Check all loops are parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
genericOp.getNumLoops() < 1)
return false;

// Check there are arity-inputs, 1-output and all are identity-maps.
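The removed hasPureTensorSemantics check previously restricted the elementwise unary/binary match to tensor operands; with it gone, memref-based generics can be recognized as well. A hedged sketch (shapes mirror the memref test at the end of this PR, but the snippet itself is illustrative) of a generic that should now specialize to linalg.exp:

#id3d = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func.func @exp_on_memrefs(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
  // A single-unary elementwise generic on memrefs; after this change it
  // qualifies for specialization to linalg.exp.
  linalg.generic {
      indexing_maps = [#id3d, #id3d],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>) {
    ^bb0(%in: f32, %out: f32):
      %e = math.exp %in : f32
      linalg.yield %e : f32
  }
  return
}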
229 changes: 229 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -11,12 +11,22 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGSPECIALIZEGENERICOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-specialization"

#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
@@ -58,6 +68,197 @@ static bool areBinOpsSwapped(GenericOp genericOp) {
return swapped;
}

//===----------------------------------------------------------------------===//
// Specialize linalg generic to matmul variants.
//===----------------------------------------------------------------------===//
/// Identifies a linalg.generic that is essentially a named op of the form:
//   `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`
//
// It is possible that a linalg.generic implements a matmul, but not in a
// straightforward way; e.g. below is a matrix multiply over some slice:
// ```
// %0 = linalg.generic {
// indexing_maps = [affine_map<(d0, d1, d2) -> (3, d1, d0)>,
// affine_map<(d0, d1, d2) -> (d0, 5, d2)>,
// affine_map<(d0, d1, d2) -> (d2, d1, 13)>],
// iterator_types = ["parallel", "parallel", "parallel"]}
// ins(%A, %B : tensor<20x20x20xf32>, tensor<20x20x20xf32>)
// outs(%C : tensor<20x20x20xf32>) {
// ^bb0(%a: f32, %b: f32, %c : f32):
// %mul = arith.mulf %a, %b : f32
// %add = arith.addf %mul, %c : f32
// linalg.yield %add : f32
// } -> tensor<20x20x20xf32>
[Review comment from a Contributor on lines +80 to +91: "Could you use it to write a negative test?" A sketch of such a test appears after the test file at the end.]
// ```
// It is not possible to represent the above as a named op:
// e.g. linalg.batch_matmul(%A, %B : tensor<20x20x20xf32>, ...) is
// not the same as the linalg.generic above.
namespace {
enum class IndexMatchResult {
Match = 0, // identity map.
Transposed, // transposed map.
Mismatch // none of the above.
};

// Checks whether the input affine `map` contains two consecutive dims that
// can be interpreted as accessing a 2D matrix. It is assumed that the row
// and column dimensions are adjacent axes (in this order) and start at
// `rowDimIdx` in the input map.
//
// e.g. consider the A matrix in `C[M,N] = A[M,K] * B[K,N]`. We check
// whether the map of A is the identity (match), transposed, or something
// completely different (mismatch). Similarly for B and C.
static IndexMatchResult matchOperandMap(AffineMap map, unsigned rowDimIdx,
unsigned expectedPosOfRowDim,
unsigned expectedPosOfColDim) {
// Get the matrix multiply indices. They are past the batch indices.
auto exprOfRowDim = map.getResults()[rowDimIdx];
auto exprOfColDim = map.getResults()[rowDimIdx + 1];

// They should be pure dimension ids.
if (exprOfRowDim.getKind() != AffineExprKind::DimId ||
exprOfColDim.getKind() != AffineExprKind::DimId)
return IndexMatchResult::Mismatch;

auto posRowDim = cast<AffineDimExpr>(exprOfRowDim).getPosition();
auto posColDim = cast<AffineDimExpr>(exprOfColDim).getPosition();

if (expectedPosOfRowDim == posRowDim && expectedPosOfColDim == posColDim)
return IndexMatchResult::Match;

if (expectedPosOfRowDim == posColDim && expectedPosOfColDim == posRowDim)
return IndexMatchResult::Transposed;

return IndexMatchResult::Mismatch;
}

// Replaces genericOp with a `NamedOpTy` op, supplied as a template arg.
// All the variants, expressed as the pseudo regular expression
// `linalg.{batch_}?matmul{_transpose_a | _transpose_b}?`,
// have the same number of ins/outs, so it is easy to stamp out the
// different versions.
template <typename NamedOpTy>
static LinalgOp replaceWithMatmulVariant(RewriterBase &rewriter, GenericOp op) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<NamedOpTy>(
op, ValueRange{op.getDpsInputs()[0], op.getDpsInputs()[1]},
ValueRange{op.getDpsInits()[0]});
return namedOp;
}

// Converts linalg.generic to named linalg.*matmul* where possible.
static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
GenericOp genericOp) {
if (genericOp.getNumDpsInputs() != 2 || genericOp.getNumDpsInits() != 1)
return failure();

// Early exit if not projected permutations.
auto mapRange = genericOp.getIndexingMapsArray();
if (llvm::any_of(mapRange,
[](AffineMap m) { return !m.isProjectedPermutation(); }))
return failure();

// A linalg generic contraction can be across multiple axes, e.g.
// ```
// linalg.generic
// {indexing_maps = [affine_map<(m, n, k1, k2) -> (m, k1, k2)>,
// affine_map<(m, n, k1, k2) -> (k2, k1, n)>,
// affine_map<(m, n, k1, k2) -> (m, n)>],
// iterator_types = ["parallel", "parallel",
// "reduction", "reduction"]}
// ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
// outs(%C : tensor<10x40xf32>) {
// ^bb0(%a: f32, %b: f32, %c: f32):
// %1 = arith.mulf %a, %b : f32
// %2 = arith.addf %c, %1 : f32
// linalg.yield %2 : f32
// } -> tensor<10x40xf32>
// ```
// In the above contraction there are two reduction dimensions {k1, k2};
// although it is a valid linalg contraction, it is not a matrix-multiply-like
// named op. Therefore, reject multi-dim reductions.
auto res = inferContractionDims(genericOp);
if (!succeeded(res))
return failure();
auto dims = *res;
if (dims.m.size() != 1 || dims.n.size() != 1 || dims.k.size() != 1)
return failure();

if (!mlir::linalg::detail::isContractionBody(
*genericOp.getBlock(), [](Operation *first, Operation *second) {
if ((isa<arith::MulFOp>(first) && isa<arith::AddFOp>(second)) ||
(isa<arith::MulIOp>(first) && isa<arith::AddIOp>(second)) ||
(isa<complex::MulOp>(first) && isa<complex::AddOp>(second)))
return true;
return false;
}))
return failure();

// Check the rank of the operands.
auto indexingMaps = genericOp.getIndexingMapsArray();
if (llvm::any_of(indexingMaps, [&dims](AffineMap m) {
return m.getResults().size() !=
dims.batch.size() + 2 /* any two of {m,n,k} */;
}))
return failure();

auto numOfBatchDims = dims.batch.size();
if (indexingMaps[0].getNumDims() != numOfBatchDims + 3)
return failure();

if (numOfBatchDims) {
// Each operand in a linalg generic contraction could express a different
// permutation of its batch dimensions. But for a named op they must be the
// identity, since separate maps cannot be specified.
if (llvm::any_of(indexingMaps, [numOfBatchDims](AffineMap m) {
for (unsigned i = 0; i < numOfBatchDims; ++i) {
auto expr = m.getResults()[i];
if (expr.getKind() != AffineExprKind::DimId ||
cast<AffineDimExpr>(expr).getPosition() != i)
return true;
}
return false;
}))
return failure();
}

auto a =
matchOperandMap(indexingMaps[0], numOfBatchDims, dims.m[0], dims.k[0]);
auto b =
matchOperandMap(indexingMaps[1], numOfBatchDims, dims.k[0], dims.n[0]);
auto c =
matchOperandMap(indexingMaps[2], numOfBatchDims, dims.m[0], dims.n[0]);

if (llvm::any_of(ArrayRef<IndexMatchResult>{a, b, c}, [](IndexMatchResult r) {
return r == IndexMatchResult::Mismatch;
}))
return failure();

if (c != IndexMatchResult::Match ||
(a == IndexMatchResult::Transposed && b == IndexMatchResult::Transposed))
return failure();

// Codegen the different matmul variants.
if (numOfBatchDims) {
if (a == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
genericOp);
if (b == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}

if (a == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
if (b == IndexMatchResult::Transposed)
return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}

} // namespace

//===----------------------------------------------------------------------===//
// Categorize a linalg.generic and convert it to a named op where possible.
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
@@ -100,5 +301,33 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
return namedOp;
}
}

if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
return failure();
}

namespace {
struct LinalgSpecializeGenericOpsPass
: public impl::LinalgSpecializeGenericOpsPassBase<
LinalgSpecializeGenericOpsPass> {

using impl::LinalgSpecializeGenericOpsPassBase<
LinalgSpecializeGenericOpsPass>::LinalgSpecializeGenericOpsPassBase;
void runOnOperation() override;
};
} // namespace

void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
}

void mlir::linalg::populateLinalgGenericOpsSpecializationPatterns(
RewritePatternSet &patterns) {
patterns.add<LinalgSpecializationPattern>(patterns.getContext());
}
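To make the variant selection in specializeLinalgContractions concrete, here is a sketch (assumed names and shapes, not part of the patch) of a contraction whose A operand is accessed transposed; following the matching logic above it should be rewritten to linalg.matmul_transpose_a:

#map_a = affine_map<(m, n, k) -> (k, m)>
#map_b = affine_map<(m, n, k) -> (k, n)>
#map_c = affine_map<(m, n, k) -> (m, n)>

func.func @matmul_with_transposed_a(%A: tensor<8x16xf32>, %B: tensor<8x32xf32>,
                                    %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
  // The A map (k, m) is the transposed form of the canonical (m, k) map, so
  // matchOperandMap returns Transposed for A and Match for B and C.
  %0 = linalg.generic {
         indexing_maps = [#map_a, #map_b, #map_c],
         iterator_types = ["parallel", "parallel", "reduction"]}
         ins(%A, %B : tensor<8x16xf32>, tensor<8x32xf32>)
         outs(%C : tensor<16x32xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %mul = arith.mulf %a, %b : f32
      %add = arith.addf %c, %mul : f32
      linalg.yield %add : f32
  } -> tensor<16x32xf32>
  return %0 : tensor<16x32xf32>
}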
52 changes: 52 additions & 0 deletions mlir/test/Dialect/Linalg/roundtrip-linalg-named-ops.mlir
@@ -0,0 +1,52 @@
// The following tests check examples of linalg named ops lowered to linalg.generic and then
// lifted back up to the named op.
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

func.func @unary_exp(%A: memref<7x14x21xf32>, %Out: memref<7x14x21xf32>) {
linalg.exp ins(%A : memref<7x14x21xf32>) outs(%Out : memref<7x14x21xf32>)
return
}

// CHECK-LABEL: unary_exp
// CHECK-SAME: %[[A:.+]]: memref<7x14x21xf32>, %[[Out:.+]]: memref<7x14x21xf32>)
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[A]] : memref<7x14x21xf32>) outs(%[[Out]] : memref<7x14x21xf32>)

// -----

func.func @binary_add(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.add ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: binary_add
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// -----

func.func @matmul(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

// CHECK-LABEL: @matmul
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>

// -----

func.func @mixed_named_ops(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>,
%C: tensor<?x?xf32>, %Out: tensor<?x?xf32>) -> tensor<?x?xf32> {
%AB = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = linalg.add ins(%AB, %C : tensor<?x?xf32>, tensor<?x?xf32>) outs(%Out : tensor<?x?xf32>) -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}

// CHECK-LABEL: @mixed_named_ops
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[B:.+]]: tensor<?x?xf32>, %[[C:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: %[[AB:.+]] = linalg.matmul ins(%[[A]], %[[B]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: linalg.add ins(%[[AB]], %[[C]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?xf32>) -> tensor<?x?xf32>
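In the spirit of the reviewer's request for a negative test, a hedged sketch of one (not part of this PR). It uses the multi-reduction contraction discussed in Specialize.cpp, which is valid IR but has no matmul-style named op, so the pass is expected to leave the generic untouched:

// RUN: mlir-opt %s --linalg-specialize-generic-ops | FileCheck %s

#red_a = affine_map<(m, n, k1, k2) -> (m, k1, k2)>
#red_b = affine_map<(m, n, k1, k2) -> (k2, k1, n)>
#red_c = affine_map<(m, n, k1, k2) -> (m, n)>

func.func @negative_multi_dim_reduction(%A: tensor<10x20x30xf32>, %B: tensor<30x20x40xf32>,
                                        %C: tensor<10x40xf32>) -> tensor<10x40xf32> {
  // Two reduction dims (k1, k2): a valid contraction, but not specializable.
  %0 = linalg.generic {
         indexing_maps = [#red_a, #red_b, #red_c],
         iterator_types = ["parallel", "parallel", "reduction", "reduction"]}
         ins(%A, %B : tensor<10x20x30xf32>, tensor<30x20x40xf32>)
         outs(%C : tensor<10x40xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %1 = arith.mulf %a, %b : f32
      %2 = arith.addf %c, %1 : f32
      linalg.yield %2 : f32
  } -> tensor<10x40xf32>
  return %0 : tensor<10x40xf32>
}

// CHECK-LABEL: @negative_multi_dim_reduction
// CHECK: linalg.generic
// CHECK-NOT: linalg.matmul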