
[MLIR][Linalg] Add more specialize patterns #91153


Merged: 4 commits, May 20, 2024
16 changes: 16 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
class GenericOp;

namespace detail {
/// Implementation of the method that checks if given operands
@@ -115,6 +116,21 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);

/// Checks whether a given `genericOp` is semantically equivalent to a single
/// linalg elementwise unary op, e.g. linalg.exp.
/// A linalg.generic body could be a series of unary elementwise ops, e.g.
/// `exp(neg(x))`, such as formed by linalg op fusion. Here we restrict to
/// detecting cases where the body is a single computation op.
bool isaElemwiseSingleUnaryOpInterface(GenericOp genericOp);

/// Checks whether `genericOp` is semantically equivalent to a single linalg
/// elementwise binary op e.g. linalg.sub.
bool isaElemwiseSingleBinaryOpInterface(GenericOp genericOp);

/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
/// Returns the scalar fill value if true.
std::optional<Value> isaFillOpInterface(GenericOp genericOp);
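
As an illustrative sketch (not part of this patch), a generic that the single-unary matcher is intended to accept has exactly one computation op feeding the yield:

#id = affine_map<(d0) -> (d0)>
func.func @unary_example(%x: tensor<8xf32>, %init: tensor<8xf32>) -> tensor<8xf32> {
  %r = linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel"]}
      ins(%x : tensor<8xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %e = math.exp %in : f32   // single computation op before the yield
    linalg.yield %e : f32
  } -> tensor<8xf32>
  return %r : tensor<8xf32>
}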

namespace detail {

/// Returns true if the block contains a contraction of the following form:
93 changes: 93 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,99 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
// Structural.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
return std::nullopt;

// Input should be referenced and init should not.
if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
return std::nullopt;

OpOperand *value = genericOp.getDpsInputOperand(0);
if (!genericOp.isScalar(value))
return std::nullopt;

Block *body = genericOp.getBody();
if (body->getOperations().size() != 1)
return std::nullopt;

auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
yieldOp->getOperand(0) != body->getArgument(0))
return std::nullopt;
return value->get();
}
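
A minimal fill-like generic this matcher recognizes, annotated with the checks above (an illustrative sketch; the test later in this PR exercises the same shape):

#scalar = affine_map<(d0, d1) -> ()>
#id = affine_map<(d0, d1) -> (d0, d1)>
func.func @fill_example(%v: f32, %init: tensor<4x4xf32>) -> tensor<4x4xf32> {
  // One scalar input, one init; all loops parallel; init unused in the body.
  %r = linalg.generic {indexing_maps = [#scalar, #id], iterator_types = ["parallel", "parallel"]}
      ins(%v : f32) outs(%init : tensor<4x4xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32   // body is the single yield of the scalar input
  } -> tensor<4x4xf32>
  return %r : tensor<4x4xf32>
}

isaFillOpInterface returns %v here, the scalar fill value.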

//===----------------------------------------------------------------------===//
// Elementwise Single Unary/Binary-OpInterface implementation
//===----------------------------------------------------------------------===//
static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check that all loops are parallel and the op has pure tensor semantics.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
return false;

// Check there are `arity` inputs, one output, and all indexing maps are identity.
if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
!llvm::all_of(genericOp.getIndexingMapsArray(),
[](AffineMap map) { return map.isIdentity(); }))
return false;

Review thread (on the identity-maps check):

MaheshRavishankar (Contributor): I think this is related to https://discourse.llvm.org/t/notes-from-the-mlir-upstream-round-table-eurollvm-2024/78374/11?u=maheshravishankar . Please correct me if I am wrong, but IMO this is too restrictive. It is perfectly reasonable for binary operations to have some "explicit broadcasting support". Is this already an assumption of these ops, or is this being added here?

javedabsar1 (Author, May 6, 2024): @MaheshRavishankar: Good point on broadcast. I hope I got your exact question right. Implicit broadcast is not supported by the linalg.add implementation, e.g.
= linalg.add ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x100xf32>) outs(%arg2: tensor<10x100xf32>) -> tensor<10x100xf32>
error: 'linalg.add' op expected operand rank (1) to match the result rank of indexing_map #0 (2)

MaheshRavishankar (Contributor): Ok, thanks.

// Init should not be referenced for elementwise operations.
if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
return false;

// A linalg.generic could be a series of elementwise ops, e.g. exp(neg(x)),
// such as resulting from producer-consumer fusion. Here we restrict to two
// ops in the body: the single elementwise computation op and a yield.
Block *body = genericOp.getBody();
if (body->getOperations().size() != 2)
return false;

Review thread (on the two-op body restriction):

Contributor: This seems like an unnecessary restriction. You could have an "elementwise operation" that cannot be a single instruction, but a sequence. Shouldn't matter.

javedabsar1 (Author, May 20, 2024): You are right, a truly isaElementwiseUnaryOp could be a sequence. Changed the API name to be more specific to this context (isaElemwiseSingleUnaryOrBinaryOpInterface), as the objective here is raising to a single named op, e.g. linalg.add, rather than a series of them. Come to think of it, un-fuse followed by generic-to-named is probably the way, rather than unthreading it all here.

Not so much for this diff, but for a binary op the elementwise semantics are more interesting:
%add1 = arith.addf %0, %1 : f32
%sub = arith.subf %2, %3 : f32
versus
%add1 = arith.addf %0, %1 : f32
%sub = arith.subf %add1, %3 : f32
The former looks like the result of sibling fusion, the latter producer-consumer fusion. Both lead to more than two input operands being required, and then one wonders whether it is really a 'binary' op.

Contributor: Ok, thanks!

Operation *op = &body->front();
if (op->getNumOperands() != arity || op->getNumResults() != 1)
return false;

auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
yieldOp->getOperand(0).getDefiningOp() != op)
return false;
return true;
}
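
For contrast, an illustrative body this helper rejects: producer-consumer fusion leaves two computation ops before the yield, so body->getOperations().size() is 3, not 2:

#id = affine_map<(d0) -> (d0)>
func.func @fused_example(%x: tensor<8xf32>, %init: tensor<8xf32>) -> tensor<8xf32> {
  %r = linalg.generic {indexing_maps = [#id, #id], iterator_types = ["parallel"]}
      ins(%x : tensor<8xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%in: f32, %out: f32):
    %n = arith.negf %in : f32   // first computation op
    %e = math.exp %n : f32      // second computation op
    linalg.yield %e : f32
  } -> tensor<8xf32>
  return %r : tensor<8xf32>
}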

bool linalg::isaElemwiseSingleUnaryOpInterface(linalg::GenericOp genericOp) {
// All basic elemwise checks.
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 1))
return false;

// Check the input is actually used.
if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
return false;
return true;
}

bool linalg::isaElemwiseSingleBinaryOpInterface(linalg::GenericOp genericOp) {
if (!isaElemwiseSingleUnaryOrBinaryOpInterface(genericOp, 2))
return false;

// Check both inputs are used (elementwise).
OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
!genericOp.payloadUsesValueFromOperand(inputOpOperand1))
return false;
return true;
}
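
An illustrative case the binary matcher rejects: the body op has two operands, but both refer to the same block argument, so the second input is never referenced:

#id = affine_map<(d0) -> (d0)>
func.func @not_binary_example(%a: tensor<8xf32>, %b: tensor<8xf32>, %init: tensor<8xf32>) -> tensor<8xf32> {
  %r = linalg.generic {indexing_maps = [#id, #id, #id], iterator_types = ["parallel"]}
      ins(%a, %b : tensor<8xf32>, tensor<8xf32>) outs(%init : tensor<8xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %s = arith.addf %x, %x : f32   // %y unused: payloadUsesValueFromOperand fails for input 1
    linalg.yield %s : f32
  } -> tensor<8xf32>
  return %r : tensor<8xf32>
}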

//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
72 changes: 72 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -14,19 +14,91 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-specialization"

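// The helper macros below rewrite a matched linalg.generic into the named op
// NEWOP, forwarding its DPS inputs and inits. OPERANDS_SWAP reorders the two
// inputs of a binary op when the generic's body consumed them in swapped order.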
#define REPLACE_BINARY_OP(NEWOP, OPERANDS_SWAP) \
(rewriter.replaceOpWithNewOp<NEWOP>( \
genericOp, \
ValueRange{genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 1 : 0], \
genericOp.getDpsInputs()[(OPERANDS_SWAP) ? 0 : 1]}, \
ValueRange{genericOp.getDpsInits()[0]}))

#define REPLACE_UNARY_OP(NEWOP) \
(rewriter.replaceOpWithNewOp<NEWOP>(genericOp, \
ValueRange{genericOp.getDpsInputs()[0]}, \
ValueRange{genericOp.getDpsInits()[0]}))

using namespace mlir;
using namespace mlir::linalg;

// Given an elementwise single binary linalg generic op, checks whether the
// binary op accesses its operands in swapped order, e.g. this differentiates
// between a linalg.generic body that contains:
// ^bb0(%a: f32, %b: f32, %c : f32):
// %0 = arith.subf %a, %b : f32
// linalg.yield %0: f32
// and one that contains:
// ^bb0(%a: f32, %b: f32, %c : f32):
// %0 = arith.subf %b, %a : f32
// linalg.yield %0: f32
// The former is linalg.sub(a,b), the latter linalg.sub(b,a).
static bool areBinOpsSwapped(GenericOp genericOp) {
Block *body = genericOp.getBody();
Operation *op = &body->front();
bool swapped = false;
if (op->getOpOperand(0).get() != body->getArgument(0)) {
swapped = true;
assert(op->getOpOperand(0).get() == body->getArgument(1) &&
op->getOpOperand(1).get() == body->getArgument(0) &&
"binary op uses just one block arg");
}
return swapped;
}

FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
if (isaCopyOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}

if (isaFillOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}

if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
Operation *op = &genericOp.getBody()->front();
if (isa<math::ExpOp>(op)) {
LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
return namedOp;
}
}

if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
bool swap = areBinOpsSwapped(genericOp);
Operation *op = &genericOp.getBody()->front();
if (isa<arith::AddFOp>(op)) {
LinalgOp namedOp = REPLACE_BINARY_OP(AddOp, swap);
return namedOp;
}
if (isa<arith::SubFOp>(op)) {
LinalgOp namedOp = REPLACE_BINARY_OP(SubOp, swap);
return namedOp;
}
if (isa<arith::MulFOp>(op)) {
LinalgOp namedOp = REPLACE_BINARY_OP(MulOp, swap);
return namedOp;
}
if (isa<arith::DivFOp>(op)) {
LinalgOp namedOp = REPLACE_BINARY_OP(DivOp, swap);
return namedOp;
}
}
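// Nothing matched: leave the generic op as-is and signal failure.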
return failure();
}
25 changes: 25 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -141,3 +141,28 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

#map = affine_map<(d0, d1) -> ()>
#map1 = affine_map<(d0, d1) -> (d0, d1)>
func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
%cst = arith.constant 0.000000e+00 : f32
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<7x7xf32>
return %0 : tensor<7x7xf32>
}
// CHECK-LABEL: linalg_generic_fill
// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
@@ -0,0 +1,76 @@ (new test file)
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.addf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_add
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.subf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_sub_swapped_operands(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.subf %in_0, %in : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_sub
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.sub ins(%[[ARG1]], %[[ARG0]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.mulf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_mul
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>

func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%1 = arith.divf %in, %in_0 : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: specialize_div
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>


module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
@@ -0,0 +1,25 @@ (new test file)
// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s

#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%0 = linalg.generic
{indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
^bb0(%in: f32, %out: f32):
%1 = math.exp %in : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: specialize_exp
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// CHECK-NOT: linalg.generic
// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}