Skip to content

[mlir][transform] Add elementwise criteria to match.structured.body #79626

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 31, 2024

Conversation

srcarroll
Copy link
Contributor

@srcarroll srcarroll commented Jan 26, 2024

As far as I am aware, there is no simple way to match on elementwise ops. I propose to add an elementwise criteria to the match.structured.body op. Although my only hesitation is that elementwise is not only determined by the body, but also the indexing maps. So if others find this too awkward, I can implement a separate match op instead.

@chelini chelini requested review from ftynse and removed request for ftynse January 26, 2024 20:27
@chelini
Copy link
Contributor

chelini commented Jan 26, 2024

Is it ready for review? Is marked as "draft".

@srcarroll srcarroll marked this pull request as ready for review January 26, 2024 21:05
@llvmbot
Copy link
Member

llvmbot commented Jan 26, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (srcarroll)

Changes

As far as I am aware, there is no simple way to match on elementwise ops. I propose to add an elementwise criteria to the match.structured.body op. Although my only hesitation is that elementwise is not only determined by the body, but also the indexing maps. So if others find this too awkward, I can implement a separate match op instead.


Full diff: https://github.com/llvm/llvm-project/pull/79626.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td (+4)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp (+8-1)
  • (modified) mlir/test/Dialect/Linalg/match-ops-interpreter.mlir (+57)
  • (modified) mlir/test/Dialect/Linalg/match-ops-invalid.mlir (+1-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 162dd05f93030f..dfeb8ae5d5ddbc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -106,6 +106,9 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
       * `passthrough`: the body of the structured payload op only forwards
         inputs to the outputs (copy or broadcast).
 
+      * `elementwise`: the body of the structured payload op represents an
+        elementwise operation.
+
       * `contraction`: the body of the structured payload op is a contraction
         of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and
         `<red>` are binary operations whose names are specified in the attribute
@@ -123,6 +126,7 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
   let arguments = (ins TransformHandleTypeInterface:$operand_handle,
                        OptionalAttr<I64Attr>:$reduction_position,
                        UnitAttr:$passthrough,
+                       UnitAttr:$elementwise,
                        OptionalAttr<StrArrayAttr>:$contraction);
   let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
   let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 115da4b90e063a..fb18886c16b16d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
@@ -187,6 +188,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
     }
     return DiagnosedSilenceableFailure::success();
   }
+  if (getElementwise()) {
+    if (!isElementwise(linalgOp))
+      return emitSilenceableError() << "not elementwise";
+    return DiagnosedSilenceableFailure::success();
+  }
   if (std::optional<ArrayAttr> contractionOps = getContraction()) {
     Block &body = linalgOp->getRegion(0).front();
     std::string message;
@@ -209,13 +215,14 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
 
 LogicalResult transform::MatchStructuredBodyOp::verify() {
   int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
-                       getContraction().has_value();
+                       getElementwise() + getContraction().has_value();
 
   if (numOptions > 1) {
     std::string attributeNames;
     llvm::raw_string_ostream os(attributeNames);
     llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
                                                getPassthroughAttrName(),
+                                               getElementwiseAttrName(),
                                                getContractionAttrName()},
                           os);
     return emitOpError() << "only one of {" << os.str() << "} is allowed";
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index a7353a4c38881e..24c7bdd9e1050e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -180,6 +180,63 @@ module attributes { transform.with_named_sequence } {
 
 // -----
 
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
+    transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
+    transform.yield
+  }
+
+  transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+    ^bb0(%arg1: !transform.any_op):
+      transform.match.structured.body %arg1 { elementwise } : !transform.any_op
+      transform.match.structured.yield %arg1 : !transform.any_op
+    }
+    transform.yield %0 : !transform.any_op
+  }
+
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
+    transform.foreach_match in %arg0
+        @match_structured_body_elementwise -> @print_elementwise
+        : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+
+  func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+    %cst0 = arith.constant 0.0 : f32
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    // expected-remark @below {{elementwise}}
+    %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
+    // expected-remark @below {{elementwise}}
+    %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
+    %non_elementwise = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+          %0 = arith.addf %arg0, %arg1 : f32
+          %1 = tensor.dim %add, %c0 : tensor<2xf32>
+          %2 = arith.subi %1, %c1 : index
+          %3 = tensor.extract %add[%2] : tensor<2xf32>
+          %4 = arith.mulf %0, %3 : f32
+          linalg.yield %4 : f32
+      } -> tensor<2x3xf32>
+    // expected-remark @below {{elementwise}}
+    %add_bcast = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"]}
+      ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+        ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+          %0 = arith.addf %arg0, %arg1 : f32
+          linalg.yield %0 : f32
+      } -> tensor<2x3xf32>
+    return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+  }
+}
+
+// -----
+
 module attributes { transform.with_named_sequence } {
   transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
     transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
index ec99e205090c4c..9ff430a3503606 100644
--- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
@@ -64,7 +64,7 @@ transform.sequence failures(suppress) {
 ^bb0(%arg0: !transform.any_op):
   transform.match.structured %arg0 : !transform.any_op {
   ^bb1(%arg1: !transform.any_op):
-    // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}}
+    // expected-error @below {{only one of {"reduction_position", "passthrough", "elementwise", "contraction"} is allowed}}
     transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
     transform.match.structured.yield
   }

@chelini chelini requested a review from ftynse January 26, 2024 21:23
@srcarroll
Copy link
Contributor Author

anyone know what's up with the buildkite failure?

@ftynse
Copy link
Member

ftynse commented Jan 31, 2024

Looks like whatever Windows container it runs on is missing a numpy installation. Not correlated with this patch in any way.

@ftynse ftynse merged commit 488f88b into llvm:main Jan 31, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants