Skip to content

Commit 488f88b

Browse files
authored
[mlir][transform] Add elementwise criteria to match.structured.body (llvm#79626)
As far as I am aware, there is no simple way to match on elementwise ops. I propose to add an `elementwise` criteria to the `match.structured.body` op. Although my only hesitation is that elementwise is not only determined by the body, but also the indexing maps. So if others find this too awkward, I can implement a separate match op instead.
1 parent 89f87c3 commit 488f88b

File tree

4 files changed

+70
-2
lines changed

4 files changed

+70
-2
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
106106
* `passthrough`: the body of the structured payload op only forwards
107107
inputs to the outputs (copy or broadcast).
108108

109+
* `elementwise`: the body of the structured payload op represents an
110+
elementwise operation.
111+
109112
* `contraction`: the body of the structured payload op is a contraction
110113
of the form `<red>(<elem>(bbarg0, bbarg1), bbarg2)` where `<elem>` and
111114
`<red>` are binary operations whose names are specified in the attribute
@@ -123,6 +126,7 @@ def MatchStructuredBodyOp : Op<Transform_Dialect, "match.structured.body", [
123126
let arguments = (ins TransformHandleTypeInterface:$operand_handle,
124127
OptionalAttr<I64Attr>:$reduction_position,
125128
UnitAttr:$passthrough,
129+
UnitAttr:$elementwise,
126130
OptionalAttr<StrArrayAttr>:$contraction);
127131
let assemblyFormat = "$operand_handle attr-dict `:` type($operand_handle)";
128132
let extraClassDeclaration = SingleOpMatcher.extraDeclaration;

mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1212
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1313
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
14+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1415
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
1516
#include "mlir/IR/BuiltinAttributes.h"
1617
#include "mlir/Interfaces/FunctionImplementation.h"
@@ -187,6 +188,11 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
187188
}
188189
return DiagnosedSilenceableFailure::success();
189190
}
191+
if (getElementwise()) {
192+
if (!isElementwise(linalgOp))
193+
return emitSilenceableError() << "not elementwise";
194+
return DiagnosedSilenceableFailure::success();
195+
}
190196
if (std::optional<ArrayAttr> contractionOps = getContraction()) {
191197
Block &body = linalgOp->getRegion(0).front();
192198
std::string message;
@@ -209,13 +215,14 @@ DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
209215

210216
LogicalResult transform::MatchStructuredBodyOp::verify() {
211217
int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
212-
getContraction().has_value();
218+
getElementwise() + getContraction().has_value();
213219

214220
if (numOptions > 1) {
215221
std::string attributeNames;
216222
llvm::raw_string_ostream os(attributeNames);
217223
llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
218224
getPassthroughAttrName(),
225+
getElementwiseAttrName(),
219226
getContractionAttrName()},
220227
os);
221228
return emitOpError() << "only one of {" << os.str() << "} is allowed";

mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,63 @@ module attributes { transform.with_named_sequence } {
180180

181181
// -----
182182

183+
module attributes { transform.with_named_sequence } {
184+
transform.named_sequence @print_elementwise(%arg0: !transform.any_op {transform.readonly}) {
185+
transform.debug.emit_remark_at %arg0, "elementwise" : !transform.any_op
186+
transform.yield
187+
}
188+
189+
transform.named_sequence @match_structured_body_elementwise(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
190+
%0 = transform.match.structured failures(propagate) %arg0 : (!transform.any_op) -> !transform.any_op {
191+
^bb0(%arg1: !transform.any_op):
192+
transform.match.structured.body %arg1 { elementwise } : !transform.any_op
193+
transform.match.structured.yield %arg1 : !transform.any_op
194+
}
195+
transform.yield %0 : !transform.any_op
196+
}
197+
198+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.consumed}) {
199+
transform.foreach_match in %arg0
200+
@match_structured_body_elementwise -> @print_elementwise
201+
: (!transform.any_op) -> !transform.any_op
202+
transform.yield
203+
}
204+
205+
func.func @payload(%in1: tensor<2xf32>, %in2: tensor<2xf32>, %in3: tensor<2x3xf32>, %out: tensor<2xf32>, %out2: tensor<2x3xf32>) -> (tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>) attributes { transform.target_tag = "start_here" } {
206+
%cst0 = arith.constant 0.0 : f32
207+
%c0 = arith.constant 0 : index
208+
%c1 = arith.constant 1 : index
209+
// expected-remark @below {{elementwise}}
210+
%fill = linalg.fill ins(%cst0: f32) outs(%out: tensor<2xf32>) -> tensor<2xf32>
211+
// expected-remark @below {{elementwise}}
212+
%add = linalg.map {arith.addf} ins(%in1, %in2: tensor<2xf32>, tensor<2xf32>) outs(%fill: tensor<2xf32>)
213+
%non_elementwise = linalg.generic
214+
{indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
215+
iterator_types = ["parallel", "parallel"]}
216+
ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
217+
^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
218+
%0 = arith.addf %arg0, %arg1 : f32
219+
%1 = tensor.dim %add, %c0 : tensor<2xf32>
220+
%2 = arith.subi %1, %c1 : index
221+
%3 = tensor.extract %add[%2] : tensor<2xf32>
222+
%4 = arith.mulf %0, %3 : f32
223+
linalg.yield %4 : f32
224+
} -> tensor<2x3xf32>
225+
// expected-remark @below {{elementwise}}
226+
%add_bcast = linalg.generic
227+
{indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
228+
iterator_types = ["parallel", "parallel"]}
229+
ins(%in1, %in3: tensor<2xf32>, tensor<2x3xf32>) outs(%out2: tensor<2x3xf32>) {
230+
^bb0(%arg0: f32, %arg1: f32, %arg3: f32):
231+
%0 = arith.addf %arg0, %arg1 : f32
232+
linalg.yield %0 : f32
233+
} -> tensor<2x3xf32>
234+
return %add, %add_bcast, %non_elementwise : tensor<2xf32>, tensor<2x3xf32>, tensor<2x3xf32>
235+
}
236+
}
237+
238+
// -----
239+
183240
module attributes { transform.with_named_sequence } {
184241
transform.named_sequence @print_reduction(%arg0: !transform.any_op {transform.readonly}) {
185242
transform.debug.emit_remark_at %arg0, "reduction" : !transform.any_op

mlir/test/Dialect/Linalg/match-ops-invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ transform.sequence failures(suppress) {
6464
^bb0(%arg0: !transform.any_op):
6565
transform.match.structured %arg0 : !transform.any_op {
6666
^bb1(%arg1: !transform.any_op):
67-
// expected-error @below {{only one of {"reduction_position", "passthrough", "contraction"} is allowed}}
67+
// expected-error @below {{only one of {"reduction_position", "passthrough", "elementwise", "contraction"} is allowed}}
6868
transform.match.structured.body %arg1 { passthrough, reduction_position = 0 } : !transform.any_op
6969
transform.match.structured.yield
7070
}

0 commit comments

Comments
 (0)