-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][transform] Add elementwise criteria to match.structured.body
#79626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Is it ready for review? Is marked as "draft". |
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir Author: None (srcarroll) ChangesAs far as I am aware, there is no simple way to match on elementwise ops. I propose to add an Full diff: https://github.com/llvm/llvm-project/pull/79626.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
index 162dd05f93030f..dfeb8ae5d5ddbc 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td
@@ -106,6 +106,9 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
* `passthrough`: the body of the structured payload op only forwards
inputs to the outputs (copy or broadcast).
+ * `elementwise`: the body of the structured payload op represents an
+ elementwise operation.
+
* `contraction`: the body of the structured payload op is a contraction
of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and
`<red>` are binary operations whose names are specified in the attribute
@@ -123,6 +126,7 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
let arguments = (ins TransformHandleTypeInterface:$operand_handle,
OptionalAttr<I64Attr>:$reduction_position,
UnitAttr:$passthrough,
+ UnitAttr:$elementwise,
OptionalAttr<StrArrayAttr>:$contraction);
let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 115da4b90e063a..fb18886c16b16d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -187,6 +188,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
}
return DiagnosedSilenceableFailure::success();
}
+ if (getElementwise()) {
+ if (!isElementwise(linalgOp))
+ return emitSilenceableError() << "not elementwise";
+ return DiagnosedSilenceableFailure::success();
+ }
if (std::optional<ArrayAttr> contractionOps = getContraction()) {
Block &body = linalgOp->getRegion(0).front();
std::string message;
@@ -209,13 +215,14 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
LogicalResult transform::MatchStructuredBodyOp::verify() {
int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
- getContraction().has_value();
+ getElementwise() + getContraction().has_value();
if (numOptions > 1) {
std::string attributeNames;
llvm::raw_string_ostream os(attributeNames);
llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
getPassthroughAttrName(),
+ getElementwiseAttrName(),
getContractionAttrName()},
os);
return emitOpError() << "only one of {" << os.str() << "} is allowed";
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index a7353a4c38881e..24c7bdd9e1050e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -180,6 +180,63 @@ module attributes { transform.with_named_sequence } {
// -----
+module attributes { transform.with_named_sequence } {
+ transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
+ transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
+ transform.yield
+ }
+
+ transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+ %0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
+ ^bb0(%arg1: !transform.any_op):
+ transform.match.structured.body %arg1 { elementwise } : !transform.any_op
+ transform.match.structured.yield %arg1 : !transform.any_op
+ }
+ transform.yield %0 : !transform.any_op
+ }
+
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
+ transform.foreach_match in %arg0
+ @match_structured_body_elementwise -> @print_elementwise
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+
+ func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
+ %cst0 = arith.constant 0.0 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ // expected-remark @below {{elementwise}}
+ %fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
+ // expected-remark @below {{elementwise}}
+ %add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
+ %non_elementwise = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+ %0 = arith.addf %arg0, %arg1 : f32
+ %1 = tensor.dim %add, %c0 : tensor<2xf32>
+ %2 = arith.subi %1, %c1 : index
+ %3 = tensor.extract %add[%2] : tensor<2xf32>
+ %4 = arith.mulf %0, %3 : f32
+ linalg.yield %4 : f32
+ } -> tensor<2x3xf32>
+ // expected-remark @below {{elementwise}}
+ %add_bcast = linalg.generic
+ {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]}
+ ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
+ ^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
+ %0 = arith.addf %arg0, %arg1 : f32
+ linalg.yield %0 : f32
+ } -> tensor<2x3xf32>
+ return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
+ }
+}
+
+// -----
+
module attributes { transform.with_named_sequence } {
transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
index ec99e205090c4c..9ff430a3503606 100644
--- a/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-invalid.mlir
@@ -64,7 +64,7 @@ transform.sequence failures(suppress) {
^bb0(%arg0: !transform.any_op):
transform.match.structured %arg0 : !transform.any_op {
^bb1(%arg1: !transform.any_op):
- // expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}}
+ // expected-error @below {{only one of {"reduction_position", "passthrough", "elementwise", "contraction"} is allowed}}
transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
transform.match.structured.yield
}
|
anyone know what's up with the buildkite failure? |
Looks like whatever Windows container it runs on is missing a numpy installation. Not correlated with this patch in any way. |
As far as I am aware, there is no simple way to match on elementwise ops. I propose to add an
elementwise
criteria to thematch.structured.body
op. Although my only hesitation is that elementwise is not only determined by the body, but also the indexing maps. So if others find this too awkward, I can implement a separate match op instead.