[MLIR][Linalg] Add more specialize patterns #91153

Merged: 4 commits into llvm:main, May 20, 2024

Conversation

javedabsar1 (Contributor)

Currently, only linalg.copy is recognized when trying to specialize a linalg.generic back to a named op. This diff enables recognition of more generics as named ops, e.g. linalg.fill and elementwise unary/binary ops.

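For illustration, a sketch of the kind of rewrite this enables, mirroring the tests added below (SSA names and shapes are schematic):

// Before: a linalg.generic with identity indexing maps, all-parallel
// iterators, and a single arith.addf in its body.
#map = affine_map<(d0, d1) -> (d0, d1)>
%0 = linalg.generic {indexing_maps = [#map, #map, #map],
                     iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %1 = arith.addf %in, %in_0 : f32
  linalg.yield %1 : f32
} -> tensor<?x?xf32>

// After transform.structured.specialize: the equivalent named op.
%0 = linalg.add ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
                outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>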
@llvmbot (Member)

llvmbot commented May 5, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Javed Absar (javedabsar1)

Changes

Currently, only linalg.copy is recognized when trying to specialize a linalg.generic back to a named op. This diff enables recognition of more generics as named ops, e.g. linalg.fill and elementwise unary/binary ops.


Full diff: https://github.com/llvm/llvm-project/pull/91153.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h (+12)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (+99)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp (+47)
  • (modified) mlir/test/Dialect/Linalg/transform-op-specialize.mlir (+25-1)
  • (added) mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir (+63)
  • (added) mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir (+25)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index f92843a1dcb987..7a67525c1ba674 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -28,6 +28,7 @@ namespace mlir {
 namespace linalg {
 class IteratorTypeAttr;
 class LinalgOp;
+class GenericOp;
 
 namespace detail {
 /// Implementation of the method that check if given operands
@@ -115,6 +116,17 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp);
 /// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
 bool isaCopyOpInterface(LinalgOp linalgOp);
 
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise unary op, e.g. linalg.exp.
+bool isaElementwiseUnaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a linalg
+/// elementwise binary op, e.g. linalg.sub.
+bool isaElementwiseBinaryOpInterface(GenericOp genericOp);
+
+/// Checks whether `genericOp` is semantically equivalent to a `linalg.fill`.
+bool isaFillOpInterface(GenericOp genericOp);
+
 namespace detail {
 
 /// Returns true if the block contains a contraction of the following form:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..e6611e496a4a2e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -70,6 +70,105 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
   return llvm::hasSingleElement(linalgOp.getBlock()->getOperations());
 }
 
+//===----------------------------------------------------------------------===//
+// FillOpInterface implementation
+//===----------------------------------------------------------------------===//
+bool linalg::isaFillOpInterface(GenericOp genericOp) {
+  // Structural.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
+    return false;
+
+  if (genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
+    return false;
+
+  // Input should be referenced and init should not.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)) ||
+       genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  OpOperand *value = genericOp.getDpsInputOperand(0);
+  if (!genericOp.isScalar(value))
+    return false;
+
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0) != body->getArgument(0))
+    return false;
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// Elementwise-Unary/Binary-OpInterface implementation
+//===----------------------------------------------------------------------===//
+static bool isaElementwiseUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
+                                                   unsigned arity) {
+  // Check that all loops are parallel and the op has pure tensor semantics.
+  if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
+      genericOp.getNumLoops() < 1 || !genericOp.hasPureTensorSemantics())
+    return false;
+
+  // Check there are `arity` inputs, one output, and all indexing maps are identities.
+  if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
+      !llvm::all_of(genericOp.getIndexingMapsArray(),
+                    [](AffineMap map) { return map.isIdentity(); }))
+    return false;
+
+  // Init should not be referenced for elementwise operations.
+  if (genericOp.payloadUsesValueFromOperand(genericOp.getDpsInitOperand(0)))
+    return false;
+
+  // Expect exactly two ops in the body: the unary/binary op, followed by a
+  // yield of its result.
+  Block *body = genericOp.getBody();
+  if (body->getOperations().size() != 2)
+    return false;
+
+  Operation *op = &body->front();
+  if (op->getNumOperands() != arity || op->getNumResults() != 1)
+    return false;
+
+  auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
+  if (!yieldOp || yieldOp.getNumOperands() != 1 ||
+      yieldOp->getOperand(0).getDefiningOp() != op)
+    return false;
+  return true;
+}
+
+bool linalg::isaElementwiseUnaryOpInterface(linalg::GenericOp genericOp) {
+  // All basic elemwise checks.
+  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 1))
+    return false;
+
+  // Check the input is actually used.
+  if (!genericOp.payloadUsesValueFromOperand(genericOp.getDpsInputOperand(0)))
+    return false;
+  return true;
+}
+
+bool linalg::isaElementwiseBinaryOpInterface(linalg::GenericOp genericOp) {
+  if (!isaElementwiseUnaryOrBinaryOpInterface(genericOp, 2))
+    return false;
+
+  // Check both inputs are used (elementwise).
+  OpOperand *inputOpOperand0 = genericOp.getDpsInputOperand(0);
+  OpOperand *inputOpOperand1 = genericOp.getDpsInputOperand(1);
+  if (!genericOp.payloadUsesValueFromOperand(inputOpOperand0) ||
+      !genericOp.payloadUsesValueFromOperand(inputOpOperand1))
+    return false;
+
+  // Check that args are not swapped (not all elementwise ops are commutative).
+  Block *body = genericOp.getBody();
+  Operation *op = &body->front();
+  if (op->getOpOperand(0).get() != body->getArgument(0) ||
+      op->getOpOperand(1).get() != body->getArgument(1))
+    return false;
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // ContractionOpInterface implementation
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 4c437b5db2c7b0..d3782287289a7b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -12,12 +12,25 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "linalg-specialization"
 
+#define REPLACE_BINARY_OP(NEWOP)                                               \
+  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
+      genericOp,                                                               \
+      ValueRange{genericOp.getDpsInputs()[0], genericOp.getDpsInputs()[1]},    \
+      ValueRange{genericOp.getDpsInits()[0]}))
+
+#define REPLACE_UNARY_OP(NEWOP)                                                \
+  (rewriter.replaceOpWithNewOp<NEWOP>(                                         \
+      genericOp,                                                               \
+      ValueRange{genericOp.getDpsInputs()[0]},                                 \
+      ValueRange{genericOp.getDpsInits()[0]}))
+
 using namespace mlir;
 using namespace mlir::linalg;
 
@@ -28,5 +41,39 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
         genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
     return namedOp;
   }
+
+  if (isaFillOpInterface(genericOp)) {
+    LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
+        genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
+    return namedOp;
+  }
+
+  if (isaElementwiseUnaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<math::ExpOp>(op)) {
+      LinalgOp namedOp = REPLACE_UNARY_OP(ExpOp);
+      return namedOp;
+    }
+  }
+
+  if (isaElementwiseBinaryOpInterface(genericOp)) {
+    Operation *op = &genericOp.getBody()->front();
+    if (isa<arith::AddFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(AddOp);
+      return namedOp;
+    }
+    if (isa<arith::SubFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(SubOp);
+      return namedOp;
+    }
+    if (isa<arith::MulFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(MulOp);
+      return namedOp;
+    }
+    if (isa<arith::DivFOp>(op)) {
+      LinalgOp namedOp = REPLACE_BINARY_OP(DivOp);
+      return namedOp;
+    }
+  }
   return failure();
 }
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
index 8a22c115f31170..21dd1fb56789f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -3,7 +3,6 @@
 #map = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 #map2 = affine_map<(d0, d1) -> (d1, d0)>
-
 func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
   // expected-note @below {{when applied to this op}}
   linalg.generic {
@@ -141,3 +140,28 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> ()>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @linalg_generic_fill(%arg0: tensor<7x7xf32>) -> tensor<7x7xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel"]} ins(%cst : f32) outs(%arg0 : tensor<7x7xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<7x7xf32>
+  return %0 : tensor<7x7xf32>
+}
+// CHECK-LABEL: linalg_generic_fill
+// CHECK-SAME: %[[ARG0:.+]]: tensor<7x7xf32>) -> tensor<7x7xf32>
+// CHECK:  %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %{{.*}} = linalg.fill ins(%[[CST]] : f32) outs(%[[ARG0]] : tensor<7x7xf32>) -> tensor<7x7xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
new file mode 100644
index 00000000000000..7bd3b1a1a4a4ca
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_binary.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @specialize_add(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.addf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_add
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.add ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_sub(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.subf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_sub
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.sub ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_mul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.mulf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_mul
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.mul ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+func.func @specialize_div(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg2 : tensor<?x?xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %1 = arith.divf %in, %in_0 : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+// CHECK-LABEL: specialize_div
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?x?xf32>,  %[[ARG2:.+]]: tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.div ins(%[[ARG0]], %[[ARG1]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ARG2]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
new file mode 100644
index 00000000000000..89a8baa453e905
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-specialize_elemwise_unary.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt --transform-interpreter --split-input-file --verify-diagnostics %s | FileCheck %s
+
+#umap = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @specialize_exp(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %0 = linalg.generic
+          {indexing_maps = [#umap, #umap], iterator_types = ["parallel", "parallel","parallel"]}
+          ins(%arg0 : tensor<?x?x?xf32>) outs(%arg1 : tensor<?x?x?xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %1 = math.exp %in : f32
+    linalg.yield %1 : f32
+  } -> tensor<?x?x?xf32>
+  return %0 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: specialize_exp
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+// CHECK-NOT: linalg.generic
+// CHECK: linalg.exp ins(%[[ARG0]] : tensor<?x?x?xf32>) outs(%[[ARG1]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match interface{LinalgOp} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.specialize %0 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}


github-actions bot commented May 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

//===----------------------------------------------------------------------===//
// FillOpInterface implementation
//===----------------------------------------------------------------------===//
bool linalg::isaFillOpInterface(GenericOp genericOp) {
Contributor:

It would be helpful instead of a bool this returns a std::optional<Value> that is the scalar value as well.

Contributor Author:

Fixed, thanks.


// Check there are `arity` inputs, one output, and all indexing maps are identities.
if (genericOp.getNumDpsInputs() != arity || genericOp.getNumDpsInits() != 1 ||
!llvm::all_of(genericOp.getIndexingMapsArray(),
Contributor:

I think this is related to https://discourse.llvm.org/t/notes-from-the-mlir-upstream-round-table-eurollvm-2024/78374/11?u=maheshravishankar . Please correct me if I am wrong, but IMO this is too restrictive. It is perfectly reasonable for binary operations to have some "explicit broadcasting support". Is this already an assumption of these ops, or is this being added here?

javedabsar1 (Contributor Author) commented May 6, 2024

> I think this is related to https://discourse.llvm.org/t/notes-from-the-mlir-upstream-round-table-eurollvm-2024/78374/11?u=maheshravishankar . Please correct me if I am wrong, but IMO this is too restrictive. It is perfectly reasonable for binary operations to have some "explicit broadcasting support". Is this already an assumption of these ops, or is this being added here?

@MaheshRavishankar: Good point on broadcast. I hope I got your exact question right.
Implicit broadcast is not supported by the linalg.add implementation, e.g.
= linalg.add ins(%arg0, %arg1 : tensor<10xf32>, tensor<10x100xf32>) outs(%arg2 : tensor<10x100xf32>) -> tensor<10x100xf32>
error: 'linalg.add' op expected operand rank (1) to match the result rank of indexing_map #0 (2)
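(For reference, a sketch of the explicit-broadcast form that the named ops expect: a separate linalg.broadcast feeding linalg.add. The %init value here is a hypothetical empty tensor of the result shape.)

// Rank-expand %arg0 explicitly, then use the named binary op.
%b = linalg.broadcast ins(%arg0 : tensor<10xf32>)
                      outs(%init : tensor<10x100xf32>) dimensions = [1]
%r = linalg.add ins(%b, %arg1 : tensor<10x100xf32>, tensor<10x100xf32>)
                outs(%arg2 : tensor<10x100xf32>) -> tensor<10x100xf32>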

Contributor:

Ok, thanks.

@joker-eph joker-eph changed the title [MLIR][LINALG] Add more specialize patterns [MLIR][Linalg] Add more specialize patterns May 6, 2024
MaheshRavishankar (Contributor) left a comment

Thanks. Just a few more questions.

// Expect exactly two ops in the body: the unary/binary op, followed by a
// yield of its result.
Block *body = genericOp.getBody();
if (body->getOperations().size() != 2)
Contributor:

This seems like an unnecessary restriction. You could have an "elementwise operation" that is not a single instruction but a sequence. That shouldn't matter.

javedabsar1 (Contributor Author) commented May 20, 2024

You are right, a truly elementwise unary op could be a sequence. I changed the API name to be more specific to the context (isaElemwiseSingleUnaryOrBinaryOpInterface), as the objective here is raising to a single named op, e.g. linalg.add, rather than a series of them. Come to think of it, un-fusing followed by generic-to-named conversion is probably the way to go, rather than unthreading it all here.

Not so much for this diff, but for binary ops the elementwise semantics are more interesting:
%add1 = arith.addf %0, %1 : f32
%sub = arith.subf %2, %3 : f32
versus
%add1 = arith.addf %0, %1 : f32
%sub = arith.subf %add1, %3 : f32
The former looks like the result of sibling fusion, while the latter looks like producer-consumer fusion. Both require more than two input operands, and then one wonders whether it is really a 'binary' op.
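(For illustration, a hypothetical generic of the producer-consumer kind: its body has three ops counting the yield, so the size()-based check quoted above rejects it. Names and shapes are schematic.)

// An addf feeding a subf: fused elementwise work, but not a single named
// binary op. The body has three operations, so recognition bails out.
%0 = linalg.generic {indexing_maps = [#map, #map, #map, #map],
                     iterator_types = ["parallel", "parallel"]}
    ins(%a, %b, %c : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%init : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32):
  %1 = arith.addf %in, %in_0 : f32
  %2 = arith.subf %1, %in_1 : f32
  linalg.yield %2 : f32
} -> tensor<?x?xf32>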

Contributor:

ok, thanks!

// Check that args are not swapped (not all elementwise ops are commutative).
Block *body = genericOp.getBody();
Operation *op = &body->front();
if (op->getOpOperand(0).get() != body->getArgument(0) ||
Contributor:

Not clear what this is doing?

Contributor Author:

Good catch. This was a bit cryptic. Moved it to Specialize.cpp, where the named op is created, and added an explanatory comment there. Thanks.
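(For illustration, the case this check guards against: a hypothetical generic whose body subtracts the block arguments in swapped order. Naively rewriting it to linalg.sub ins(%arg0, %arg1) would flip the semantics.)

// Computes %arg1 - %arg0 elementwise, despite the ins() operand order.
%0 = linalg.generic {indexing_maps = [#map, #map, #map],
                     iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
    outs(%arg2 : tensor<?x?xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
  %1 = arith.subf %in_0, %in : f32   // operands swapped
  linalg.yield %1 : f32
} -> tensor<?x?xf32>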

MaheshRavishankar (Contributor) left a comment

Ok, this looks fine to me.


@javedabsar1 javedabsar1 merged commit 33b7833 into llvm:main May 20, 2024
3 of 4 checks passed