
[mlir][linalg] raise generic to named ops. #110421


Merged: 3 commits, Oct 11, 2024

Changes from all commits

10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);

/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.broadcast`. Returns broadcast dimensions if true.
std::optional<SmallVector<int64_t>>
isaBroadcastOpInterface(GenericOp genericOp);

/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.transpose`. Returns permuted dimensions if true.
std::optional<SmallVector<int64_t>>
isaTransposeOpInterface(GenericOp genericOp);

/// Checks whether a given `genericOp` is semantically equivalent to a single
/// linalg elementwise unary op, e.g. linalg.exp.
/// A linalg.generic body could be a series of unary elementwise ops, e.g.

24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

@@ -243,6 +243,18 @@ def LinalgStructuredInterface
                           utils::IteratorType::parallel);
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if all loops are parallel.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isAllParallelLoops",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return getNumParallelLoops() == getNumLoops();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return the dims that are parallel loops.

@@ -327,6 +339,18 @@ def LinalgStructuredInterface
        return !getBlock()->getArgument(bbArgNumber).use_empty();
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Returns true only if linalgOp takes one input and produces one result.
      }],
      /*retTy=*/"bool",
      /*methodName=*/"isSingleInputOutput",
      /*args=*/(ins),
      /*methodBody=*/"",
      /*defaultImplementation=*/[{
        return $_op.getNumDpsInputs() == 1 && $_op.getNumDpsInits() == 1;
      }]
    >,
    InterfaceMethod<
      /*desc=*/[{
        Return true if `opOperand` is an init tensor. This is true when it is

18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

@@ -210,6 +210,24 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
    }

    MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

    // Return true only if the GenericOp has a single input and a single
    // output, and the body is a single yieldOp that yields the input.
    // This check is useful when trying to determine if the op is
    // essentially a transpose, broadcast, or copy.
    bool isSingleYieldOp() {
      if (!isSingleInputOutput())
        return false;
      Block *body = getBody();
      if (body->getOperations().size() != 1)
        return false;

      auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
      if (!yieldOp || yieldOp.getNumOperands() != 1 ||
          yieldOp->getOperand(0) != body->getArgument(0))
        return false;
      return true;
    }
  }];

  let hasCanonicalizer = 1;

153 changes: 111 additions & 42 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

@@ -22,6 +22,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <numeric>

using namespace mlir;
using namespace mlir::linalg;
@@ -53,112 +54,180 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

bool linalg::isaCopyOpInterface(LinalgOp op) {
  // Check that all loops are parallel and that the op has a single input and
  // a single output.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput())
    return false;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
      !mapRange.back().isIdentity()) {
    return false;
  }
  // Region.
  return llvm::hasSingleElement(op.getBlock()->getOperations());
}
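
For orientation, here is a sketch of the kind of generic this predicate accepts (mine, not part of the diff; value names and shapes are made up): an all-parallel, identity-mapped generic whose body just yields its input element, alongside the named form it can be raised to.

```mlir
#id = affine_map<(d0, d1) -> (d0, d1)>
// A copy written as linalg.generic: identity maps, all-parallel loops,
// and a single-op body that yields the input element unchanged.
%0 = linalg.generic
       {indexing_maps = [#id, #id],
        iterator_types = ["parallel", "parallel"]}
       ins(%src : tensor<4x8xf32>) outs(%dst : tensor<4x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<4x8xf32>

// Semantically equivalent named form:
%1 = linalg.copy ins(%src : tensor<4x8xf32>)
       outs(%dst : tensor<4x8xf32>) -> tensor<4x8xf32>
```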

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  // Input should be referenced and init should not.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)) ||
      op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return std::nullopt;

  OpOperand *value = op.getDpsInputOperand(0);
  if (!op.isScalar(value))
    return std::nullopt;
  return value->get();
}
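
Again an illustrative sketch (shapes and names assumed, not taken from the patch): a generic whose scalar input is yielded into every element of the init satisfies this predicate and corresponds to `linalg.fill`.

```mlir
#scalar = affine_map<(d0, d1) -> ()>
#id     = affine_map<(d0, d1) -> (d0, d1)>
// Fill written as linalg.generic: scalar input, init not read,
// body is a bare yield of the scalar.
%0 = linalg.generic
       {indexing_maps = [#scalar, #id],
        iterator_types = ["parallel", "parallel"]}
       ins(%cst : f32) outs(%init : tensor<4x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<4x8xf32>

// Named form:
%1 = linalg.fill ins(%cst : f32) outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
```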

//===----------------------------------------------------------------------===//
// BroadcastOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
linalg::isaBroadcastOpInterface(GenericOp op) {
  // Structural.
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto srcTy = op.getDpsInputOperand(0)->get().getType();
  auto dstTy = op.getDpsInitOperand(0)->get().getType();
  if (!isa<MemRefType, RankedTensorType>(srcTy) ||
      !isa<MemRefType, RankedTensorType>(dstTy))
    return std::nullopt;

  // Check that the output map is the identity. A broadcast could additionally
  // permute indices; that is expressible in linalg.generic but not in the
  // named broadcast op.
  auto dstMap = op.getIndexingMapsArray()[1];
  if (!dstMap.isIdentity())
    return std::nullopt;

  SmallVector<int64_t> position;
  auto srcMap = op.getIndexingMapsArray()[0];

  if (srcMap.getResults().size() >= dstMap.getResults().size())
    return std::nullopt;

  // Check that the input map is a strictly increasing sequence of DimIds.
  for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
    auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
    if (!expr)
      return std::nullopt;
    int64_t pos = expr.getPosition();
    if (i > 0 && pos <= position[i - 1])
      return std::nullopt;
    position.push_back(expr.getPosition());
  }

  SmallVector<int64_t> broadcastedDims;
  auto numDims = srcMap.getNumDims();
  // This is quadratic, but the number of dims is generally small.
  for (auto dim : llvm::seq<int64_t>(0, numDims)) {
    if (!llvm::is_contained(position, dim))
      broadcastedDims.push_back(dim);
  }
  return broadcastedDims;
}
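
To make the dimension bookkeeping concrete, a hand-written sketch (shapes assumed): the input map lists d0 and d2 in strictly increasing order, so d1 is the dimension missing from the input, and the function returns [1].

```mlir
#in  = affine_map<(d0, d1, d2) -> (d0, d2)>
#out = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// Broadcast along dim 1 written as linalg.generic.
%0 = linalg.generic
       {indexing_maps = [#in, #out],
        iterator_types = ["parallel", "parallel", "parallel"]}
       ins(%src : tensor<4x16xf32>) outs(%init : tensor<4x8x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<4x8x16xf32>

// Named form produced by the specialization:
%1 = linalg.broadcast ins(%src : tensor<4x16xf32>)
       outs(%init : tensor<4x8x16xf32>) dimensions = [1]
```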

//===----------------------------------------------------------------------===//
// TransposeOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
linalg::isaTransposeOpInterface(GenericOp op) {
  // To specialize as a transpose op, the generic op must have all-parallel
  // loops, a single input and a single output, and a body that is just a
  // yield op, returning the input as-is (no compute).
  if (!op.isAllParallelLoops() || !op.isSingleInputOutput() ||
      !op.isSingleYieldOp())
    return std::nullopt;

  auto mapRange = op.getIndexingMapsArray();
  if (mapRange.size() != 2)
    return std::nullopt;

  auto mapOfInput = mapRange.front();
  auto mapOfResult = mapRange.back();

  // linalg.transpose permutes the dimensions of the input using this
  // rule: dim(result, i) = dim(input, permutation[i]).
  if (!mapOfResult.isIdentity() || !mapOfInput.isPermutation())
    return std::nullopt;

  SmallVector<int64_t> permutation(mapOfInput.getNumDims());
  for (unsigned i = 0; i < mapOfInput.getNumDims(); ++i) {
    auto expr = llvm::cast<AffineDimExpr>(mapOfInput.getResults()[i]);
    permutation[expr.getPosition()] = i;
  }
  return permutation;
}
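
A worked sketch of the inversion (shapes assumed): with input map (d0, d1) -> (d1, d0) and an identity output map, the loop over input results assigns permutation[1] = 0 and permutation[0] = 1, i.e. permutation = [1, 0], satisfying dim(result, i) = dim(input, permutation[i]).

```mlir
#perm = affine_map<(d0, d1) -> (d1, d0)>
#id   = affine_map<(d0, d1) -> (d0, d1)>
// Transpose written as linalg.generic: permuted input map, identity output.
%0 = linalg.generic
       {indexing_maps = [#perm, #id],
        iterator_types = ["parallel", "parallel"]}
       ins(%src : tensor<4x8xf32>) outs(%init : tensor<8x4xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
} -> tensor<8x4xf32>

// Named form:
%1 = linalg.transpose ins(%src : tensor<4x8xf32>)
       outs(%init : tensor<8x4xf32>) permutation = [1, 0]
```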

//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
static bool isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp op,
                                                      unsigned arity) {
  // Check all loops are parallel.
  if (!op.isAllParallelLoops() || op.getNumLoops() < 1)
    return false;

  // Check there are arity-inputs, 1-output and all are identity-maps.
  if (op.getNumDpsInputs() != arity || op.getNumDpsInits() != 1 ||
      !llvm::all_of(op.getIndexingMapsArray(),
                    [](AffineMap map) { return map.isIdentity(); }))
    return false;

  // Init should not be referenced for elementwise operations.
  if (op.payloadUsesValueFromOperand(op.getDpsInitOperand(0)))
    return false;

  // A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
  // as results from producer-consumer fusion. Here, we restrict to two ops in
  // the body, where the first is the single elementwise op and the second a
  // yield.
  Block *body = op.getBody();
  if (body->getOperations().size() != 2)
    return false;

  Operation *oper = &body->front();
  if (oper->getNumOperands() != arity || oper->getNumResults() != 1)
    return false;

  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0).getDefiningOp() != oper)
    return false;
  return true;
}

bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp op) {
  // All basic elemwise checks.
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 1))
    return false;

  // Check that the input is actually used.
  if (!op.payloadUsesValueFromOperand(op.getDpsInputOperand(0)))
    return false;
  return true;
}

bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp op) {
  if (!isaElemwiseSingleUnaryOrBinaryOpInterface(op, 2))
    return false;

  // Check both inputs are used (elementwise).
  OpOperand *inputOpOperand0 = op.getDpsInputOperand(0);
  OpOperand *inputOpOperand1 = op.getDpsInputOperand(1);
  if (!op.payloadUsesValueFromOperand(inputOpOperand0) ||
      !op.payloadUsesValueFromOperand(inputOpOperand1))
    return false;
  return true;
}
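
Illustrative only (op choice and shapes assumed): a two-op body, one elementwise op followed by a yield of its result, with identity maps on all operands, satisfies the unary predicate. The binary case is analogous with, e.g., arith.addf and both inputs used.

```mlir
#id = affine_map<(d0) -> (d0)>
// Single elementwise unary op (exp) followed by yield.
%0 = linalg.generic
       {indexing_maps = [#id, #id], iterator_types = ["parallel"]}
       ins(%x : tensor<16xf32>) outs(%init : tensor<16xf32>) {
  ^bb0(%in: f32, %out: f32):
    %e = math.exp %in : f32
    linalg.yield %e : f32
} -> tensor<16xf32>

// Named form:
%1 = linalg.exp ins(%x : tensor<16xf32>) outs(%init : tensor<16xf32>) -> tensor<16xf32>
```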

27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
                                                      GenericOp genericOp) {
  // Copy
  if (isaCopyOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Fill
  if (isaFillOpInterface(genericOp)) {
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
    return namedOp;
  }

  // Broadcast
  std::optional<SmallVector<int64_t>> equivalentToBroadcast =
      isaBroadcastOpInterface(genericOp);
  if (equivalentToBroadcast) {
    auto dims = *equivalentToBroadcast;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        dims);
    return namedOp;
  }

  // Transpose
  std::optional<SmallVector<int64_t>> equivalentToTranspose =
      isaTransposeOpInterface(genericOp);
  if (equivalentToTranspose) {
    auto permutation = *equivalentToTranspose;
    LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
        permutation);
    return namedOp;
  }

  // Elementwise Unary
  if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
    Operation *op = &genericOp.getBody()->front();
    if (isa<math::ExpOp>(op)) {

@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
    }
  }

  // Elementwise Binary
  if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
    bool swap = areBinOpsSwapped(genericOp);
    Operation *op = &genericOp.getBody()->front();

@@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
    }
  }

  // Contraction - e.g. matmul
  if (isaContractionOpInterface(genericOp)) {
    return specializeLinalgContractions(rewriter, genericOp);
  }
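
Putting the dispatch together, an end-to-end contraction sketch (hand-written; shapes assumed, and the pass spelling `-linalg-specialize-generic-ops` is my reading of this PR stack rather than something shown in the diff): a generic with matmul indexing maps and a mul-add body is raised to `linalg.matmul`.

```mlir
#mA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mC = affine_map<(d0, d1, d2) -> (d0, d1)>
// A matmul written as linalg.generic ...
%0 = linalg.generic
       {indexing_maps = [#mA, #mB, #mC],
        iterator_types = ["parallel", "parallel", "reduction"]}
       ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
       outs(%C : tensor<4x16xf32>) {
  ^bb0(%a: f32, %b: f32, %c: f32):
    %p = arith.mulf %a, %b : f32
    %s = arith.addf %c, %p : f32
    linalg.yield %s : f32
} -> tensor<4x16xf32>

// ... is raised by the specialization to:
%1 = linalg.matmul ins(%A, %B : tensor<4x8xf32>, tensor<8x16xf32>)
       outs(%C : tensor<4x16xf32>) -> tensor<4x16xf32>
```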