[MLIR][Linalg] Add more specialize patterns #91153
Conversation
Currently only linalg.copy is recognized when trying to specialize a linalg.generic back to a named op. This patch enables recognition of more generics as named ops, e.g. linalg.fill and elementwise unary/binary ops.
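For orientation, here is a minimal sketch of how the entry point touched by this patch might be driven from a rewrite pattern. Only specializeGenericOp and its signature come from the diff below; the wrapper pattern and its name are illustrative assumptions:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

namespace {
// Hypothetical wrapper pattern: tries to rewrite each linalg.generic into
// its named-op equivalent via the recognizers added in this patch.
struct SpecializeGenericsPattern : public OpRewritePattern<linalg::GenericOp> {
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // specializeGenericOp replaces the generic in place on success
    // (copy, fill, elementwise unary/binary) and fails otherwise.
    FailureOr<linalg::LinalgOp> named =
        linalg::specializeGenericOp(rewriter, genericOp);
    return success(succeeded(named));
  }
};
} // namespace
```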
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Javed Absar (javedabsar1)
Full diff: https://github.com/llvm/llvm-project/pull/91153.diff. 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb987..7a67525c1ba674 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
namespace linalg {
class IteratorTypeAttr;
class LinalgOp;
+class GenericOp;
namespace detail {
/// Implementation of the method that check if given operands
@@ -115,6 +116,17 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise unary op, e.g. linalg.exp.
+bool isaElementwiseUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise binary op, e.g. linalg.sub.
+bool isaElementwiseBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+bool isaFillOpInterface(GenericOp genericOp);
+
namespace detail {
/// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..e6611e496a4a2e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,105 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
}
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+bool linalg::isaFillOpInterface(GenericOp genericOp) {
+ // Structural.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+ return false;
+
+ if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+ return false;
+
+ // Input should be referenced and init should not.
+ if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+ genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+ return false;
+
+ OpOperand *value = genericOp.getDpsInputOperand(0);
+ if (!genericOp.isScalar(value))
+ return false;
+
+ Block *body = genericOp.getBody();
+ if (body->getOperations().size() != 1)
+ return false;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0) != body->getArgument(0))
+ return false;
+ return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise-Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+ unsigned arity) {
+ // Check all loops are parallel, and have only tensor semantics.
+ if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+ genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+ return false;
+
+  // Check there are `arity` inputs, one output, and all identity maps.
+ if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+ !llvm::all_of(genericOp.getIndexingMapsArray(),
+ [](AffineMap map) { return map.isIdentity(); }))
+ return false;
+
+ // Init should not be referenced for elementwise operations.
+ if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+ return false;
+
+  // Expect exactly two ops in the body: the unary/binary payload op,
+  // followed by a yield of that op's result.
+ Block *body = genericOp.getBody();
+ if (body->getOperations().size() != 2)
+ return false;
+
+ Operation *op = &body->front();
+ if (op->getNumOperands() != arity || op->getNumResults() != 1)
+ return false;
+
+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+ if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+ yieldOp->getOperand(0).getDefiningOp() != op)
+ return false;
+ return true;
+}
+
+bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
+ // All basic elemwise checks.
+ if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 1))
+ return false;
+
+  // Check the input is actually used.
+ if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+ return false;
+ return true;
+}
+
+bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
+ if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 2))
+ return false;
+
+ // Check both inputs are used (elementwise).
+ OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+ OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+ if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+ !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+ return false;
+
+  // Check that the args are not swapped (not all elementwise ops commute).
+ Block *body = genericOp.getBody();
+ Operation *op = &body->front();
+ if (op->getOpOperand(0).get() != body->getArgument(0) ||
+ op->getOpOperand(1).get() != body->getArgument(1))
+ return false;
+ return true;
+}
+
//===----------------------------------------------------------------------===//
// ContractionOpInterface implementation
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b0..d3782287289a7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -12,12 +12,25 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "linalg-specialization"
+#define REPLACE_BINARY_OP(NEWOP) \
+ (rewriter.replaceOpWithNewOp<NEWOP>( \
+ genericOp, \
+ ValueRange{genericOp.getDpsInputs()[0], genericOp.getDpsInputs()[1]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
+
+#define REPLACE_UNARY_OP(NEWOP) \
+ (rewriter.replaceOpWithNewOp<NEWOP>( \
+ genericOp, \
+ ValueRange{genericOp.getDpsInputs()[0]}, \
+ ValueRange{genericOp.getDpsInits()[0]}))
+
using namespace mlir;
using namespace mlir::linalg;
@@ -28,5 +41,39 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}
+
+ if (isaFillOpInterface(genericOp)) {
+ LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+ genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+ return namedOp;
+ }
+
+ if (isaElementwiseUnaryOpInterface(genericOp)) {
+ Operation *op = &genericOp.getBody()->front();
+ if (isa<math::ExpOp>(op)) {
+ LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+ return namedOp;
+ }
+ }
+
+ if (isaElementwiseBinaryOpInterface(genericOp)) {
+ Operation *op = &genericOp.getBody()->front();
+ if (isa<arith::AddFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(AddOp);
+ return namedOp;
+ }
+ if (isa<arith::SubFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(SubOp);
+ return namedOp;
+ }
+ if (isa<arith::MulFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(MulOp);
+ return namedOp;
+ }
+ if (isa<arith::DivFOp>(op)) {
+ LinalgOp namedOp = REPLACE_BINARY_OP(DivOp);
+ return namedOp;
+ }
+ }
return failure();
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f31170..21dd1fb56789f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -3,7 +3,6 @@
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>
-
func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
@@ -141,3 +140,28 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ linalg.yield %in : f32
+ } -> tensor<7x7xf32>
+ return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 00000000000000..7bd3b1a1a4a4ca
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.addf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.subf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.mulf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+ ^bb0(%in: f32, %in_0: f32, %out: f32):
+ %1 = arith.divf %in, %in_0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 00000000000000..89a8baa453e905
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+ %0 = linalg.generic
+ {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+ ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+ ^bb0(%in: f32, %out: f32):
+ %1 = math.exp %in : f32
+ linalg.yield %1 : f32
+ } -> tensor<?x?x?xf32>
+ return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
bool linalg::isaFillOpInterface(GenericOp genericOp) {
It would be helpful if, instead of a bool, this returned a std::optional<Value> that is the scalar value as well.
Fixed. thanks.
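For illustration, a minimal sketch of what the optional<Value>-returning variant might look like; this is an assumption extrapolated from the bool version in the diff above, not the exact code of the revision:

```cpp
// Sketch only: same structural checks as the bool version (some elided);
// on success it surfaces the scalar fill value, as suggested in the review.
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
  // All loops parallel, exactly one scalar input and one init.
  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
      genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
    return std::nullopt;
  OpOperand *value = genericOp.getDpsInputOperand(0);
  if (!genericOp.isScalar(value))
    return std::nullopt;
  // Body must be a lone yield of the scalar block argument.
  Block *body = genericOp.getBody();
  if (body->getOperations().size() != 1)
    return std::nullopt;
  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
      yieldOp->getOperand(0) != body->getArgument(0))
    return std::nullopt;
  // Return the scalar being broadcast into the init tensor.
  return value->get();
}
```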
// Check there are `arity` inputs, one output, and all identity maps.
if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
    !llvm::all_of(genericOp.getIndexingMapsArray(),
I think this is related to https://discourse.llvm.org/t/notes-from-the-mlir-upstream-round-table-eurollvm-2024/78374/11?u=maheshravishankar . Please correct me if I am wrong, but IMO this is too restrictive. It is perfectly reasonable for binary operations to have some "explicit broadcasting support". Is this already an assumption of these ops, or is this being added here?
@MaheshRavishankar: Good point on broadcast. I hope I got your exact question right.
Implicit broadcast is not supported by the linalg.add implementation, e.g.
%0 = linalg.add ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x100xf32>) outs(%arg2 : tensor<10x100xf32>) -> tensor<10x100xf32>
error: 'linalg.add' op expected operand rank (1) to match the result rank of indexing_map #0 (2)
Ok thanks.
Thanks. Just a few more questions.
// Expect exactly two ops in the body: the unary/binary payload op,
// followed by a yield of that op's result.
Block *body = genericOp.getBody();
if (body->getOperations().size() != 2)
This seems like an unnecessary restriction. You could have an "elementwise operation" that cannot be a single instruction, but a sequence. Shouldn't matter.
You are right, a truly isaElementwiseUnaryOp could be a sequence. Changed the API name to be more specific to the context (isaElemwiseSingleUnaryOrBinaryOpInterface), as the objective here is raising to a single named op, e.g. linalg.add, rather than a series of them. Actually, come to think of it, un-fuse followed by generic->named is probably the way to go rather than unthreading it all here.
Not so much for this diff, but for binary ops the elementwise semantics are more interesting:
%add1 = arith.addf %0, %1 : f32
%sub = arith.subf %2, %3 : f32
versus
%add1 = arith.addf %0, %1 : f32
%sub = arith.subf %add1, %3 : f32
The former looks like the result of sibling fusion, the latter of producer-consumer fusion. Both require more than two input operands, and then one wonders whether it is really a 'binary' op.
ok, thanks!
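As an aside, a hypothetical helper (not part of this PR) sketching the sequence-tolerant check discussed above; the isPure usage and the helper name are assumptions:

```cpp
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

// Hypothetical: accept any body whose non-terminator ops are pure,
// single-result scalar ops, ending in a linalg.yield.
static bool isElementwiseBody(Block *body) {
  for (Operation &op : body->without_terminator()) {
    if (!isPure(&op) || op.getNumResults() != 1 ||
        !op.getResultTypes().front().isIntOrFloat())
      return false;
  }
  return isa<linalg::YieldOp>(body->back());
}
```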
// Check that the args are not swapped (not all elementwise ops commute).
Block *body = genericOp.getBody();
Operation *op = &body->front();
if (op->getOpOperand(0).get() != body->getArgument(0) ||
Not clear what this is doing?
Good catch, this was a bit cryptic. Moved it to Specialize.cpp, where the named op is created, and added an explanatory comment there. Thanks.
Ok, this looks fine to me.
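For readers following along, a sketch of how the moved check might read at the named-op creation site in Specialize.cpp; the placement and comment wording are assumptions, while the predicate itself is taken from the diff above:

```cpp
// Named binary ops such as linalg.sub take ins(%a, %b) positionally, so a
// generic whose body computes, say, sub(%b, %a) must not be specialized to
// linalg.sub ins(%a, %b). Only proceed when the payload op consumes the
// block arguments in their original order.
Block *body = genericOp.getBody();
Operation *op = &body->front();
if (op->getOpOperand(0).get() != body->getArgument(0) ||
    op->getOpOperand(1).get() != body->getArgument(1))
  return failure();
```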