Skip to content

Commit c439913

Browse files
[MLIR][Transform] Add attribute in MatchOp to filter by operand type (#67994)
This patchs adds the `filter_operand_types` attribute to transform::MatchOp, allowing to filter ops depending on their operand types.
1 parent 39ac5ee commit c439913

File tree

3 files changed

+84
-1
lines changed

3 files changed

+84
-1
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,11 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
574574
- attribute: the matched op must have all specified attributes (with their
575575
specified values).
576576
- filter_result_type: the matched op must return exactly this one type.
577+
- filter_operand_types: all the operands of the matched op must must be of
578+
this type. If more than a type is specified, then the length of the list
579+
must be equal to the number of operands in the matched op, and the match
580+
will succeed only if the operand types match all the types in the list
581+
in the order in which they are specified.
577582

578583
Note: Only ops that satisfy all specified constraints are matched.
579584

@@ -595,7 +600,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
595600
OptionalAttr<StrArrayAttr>:$ops,
596601
OptionalAttr<MatchInterfaceEnum>:$interface,
597602
OptionalAttr<DictionaryAttr>:$op_attrs,
598-
OptionalAttr<TypeAttr>:$filter_result_type);
603+
OptionalAttr<TypeAttr>:$filter_result_type,
604+
OptionalAttr<TypeArrayAttr>:$filter_operand_types);
599605
// TODO: variadic results when needed.
600606
let results = (outs TransformHandleTypeInterface:$results);
601607

@@ -609,6 +615,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
609615
(`interface` `{` $interface^ `}`)?
610616
(`attributes` $op_attrs^)?
611617
(`filter_result_type` `=` $filter_result_type^)?
618+
(`filter_operand_types` `=` $filter_operand_types^)?
612619
`in` $target attr-dict
613620
`:` functional-type($target, results)
614621
}];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,6 +1171,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
11711171
}
11721172

11731173
SmallVector<Operation *> res;
1174+
bool incorrectNumOperandTypes = false;
11741175
auto matchFun = [&](Operation *op) {
11751176
if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
11761177
return;
@@ -1210,12 +1211,47 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
12101211
return;
12111212
}
12121213

1214+
if (getFilterOperandTypes().has_value()) {
1215+
mlir::ArrayAttr types = getFilterOperandTypes().value();
1216+
auto operandTypes = op->getOperandTypes();
1217+
1218+
if (types.size() == 1) {
1219+
// All the operands must must be equal to the specified type
1220+
auto typeattr =
1221+
dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
1222+
Type t = typeattr.getValue().cast<::mlir::Type>();
1223+
if (!llvm::all_of(op->getOperandTypes(),
1224+
[&](Type operandType) { return operandType == t; }))
1225+
return;
1226+
} else {
1227+
// The operand types must match all the types in the list (in the same
1228+
// order in with they are specified)
1229+
if (types.size() != operandTypes.size()) {
1230+
incorrectNumOperandTypes = true;
1231+
return;
1232+
}
1233+
1234+
for (auto [attr, operandType] :
1235+
llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
1236+
auto typeattr = cast<mlir::TypeAttr>(attr);
1237+
Type type = typeattr.getValue().cast<::mlir::Type>();
1238+
1239+
if (type != operandType)
1240+
return;
1241+
}
1242+
}
1243+
}
1244+
12131245
// All constraints are satisfied.
12141246
res.push_back(op);
12151247
return;
12161248
};
12171249

12181250
(*payloadOps.begin())->walk(matchFun);
1251+
if (incorrectNumOperandTypes)
1252+
return emitDefiniteFailure("If filter_operand_types contains more than a "
1253+
"type, then it must contain as much types as "
1254+
"the number of operands in the target ops");
12191255
results.set(cast<OpResult>(getResult()), res);
12201256
return DiagnosedSilenceableFailure::success();
12211257
}

mlir/test/Dialect/Linalg/transform-op-match.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,46 @@ module attributes {transform.with_named_sequence} {
4343

4444
// -----
4545

46+
func.func @by_operand_type() {
47+
%c2 = arith.constant 2.0: f32
48+
%v = arith.constant 8: i32
49+
%r1 = math.fpowi %c2, %v : f32, i32
50+
// expected-remark @below {{matched op name}}
51+
%r2 = arith.addf %c2, %c2 : f32
52+
// expected-remark @below {{matched op name}}
53+
%r3 = arith.fptoui %r2 : f32 to i32
54+
return
55+
}
56+
57+
transform.sequence failures(propagate) {
58+
^bb1(%arg1: !transform.any_op):
59+
%match_name1 = transform.structured.match
60+
ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
61+
transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op
62+
transform.test_consume_operand %match_name1 : !transform.any_op
63+
64+
%match_name2 = transform.structured.match
65+
ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
66+
transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op
67+
transform.test_consume_operand %match_name2 : !transform.any_op
68+
69+
%no_match_name1 = transform.structured.match
70+
ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
71+
transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op
72+
transform.test_consume_operand %no_match_name1 : !transform.any_op
73+
74+
%no_match_name2 = transform.structured.match
75+
ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
76+
transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op
77+
transform.test_consume_operand %no_match_name2 : !transform.any_op
78+
79+
// expected-error @+1 {{If filter_operand_types contains more than a type, then it must contain as much types as the number of operands in the target ops}}
80+
%failure_match = transform.structured.match
81+
ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op
82+
}
83+
84+
// -----
85+
4686
func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
4787
%c0 = arith.constant 0.0 : f32
4888
// expected-remark @below {{tileable}}

0 commit comments

Comments
 (0)