Skip to content

[MLIR][Transform] Add attribute in MatchOp to filter by operand type #67994

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,11 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
- attribute: the matched op must have all specified attributes (with their
specified values).
- filter_result_type: the matched op must return exactly this one type.
- filter_operand_types: all the operands of the matched op must must be of
this type. If more than a type is specified, then the length of the list
must be equal to the number of operands in the matched op, and the match
will succeed only if the operand types match all the types in the list
in the order in which they are specified.

Note: Only ops that satisfy all specified constraints are matched.

Expand All @@ -556,7 +561,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
OptionalAttr<StrArrayAttr>:$ops,
OptionalAttr<MatchInterfaceEnum>:$interface,
OptionalAttr<DictionaryAttr>:$op_attrs,
OptionalAttr<TypeAttr>:$filter_result_type);
OptionalAttr<TypeAttr>:$filter_result_type,
OptionalAttr<TypeArrayAttr>:$filter_operand_types);
// TODO: variadic results when needed.
let results = (outs TransformHandleTypeInterface:$results);

Expand All @@ -570,6 +576,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match",
(`interface` `{` $interface^ `}`)?
(`attributes` $op_attrs^)?
(`filter_result_type` `=` $filter_result_type^)?
(`filter_operand_types` `=` $filter_operand_types^)?
`in` $target attr-dict
`:` functional-type($target, results)
}];
Expand Down
36 changes: 36 additions & 0 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
}

SmallVector<Operation *> res;
bool incorrectNumOperandTypes = false;
auto matchFun = [&](Operation *op) {
if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
return;
Expand Down Expand Up @@ -1180,12 +1181,47 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter,
return;
}

if (getFilterOperandTypes().has_value()) {
mlir::ArrayAttr types = getFilterOperandTypes().value();
auto operandTypes = op->getOperandTypes();

if (types.size() == 1) {
// All the operands must must be equal to the specified type
auto typeattr =
dyn_cast<mlir::TypeAttr>(getFilterOperandTypes().value()[0]);
Type t = typeattr.getValue().cast<::mlir::Type>();
if (!llvm::all_of(op->getOperandTypes(),
[&](Type operandType) { return operandType == t; }))
return;
} else {
// The operand types must match all the types in the list (in the same
// order in with they are specified)
if (types.size() != operandTypes.size()) {
incorrectNumOperandTypes = true;
return;
}

for (auto [attr, operandType] :
llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) {
auto typeattr = cast<mlir::TypeAttr>(attr);
Type type = typeattr.getValue().cast<::mlir::Type>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the cast necessary? Also please write “cast<…>(x)”. the other syntax is deprecated.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think it's needed because attr is of type Attribute but we need it to be TypeAttr. I have changed the cast to the one you suggested, thanks for the tip!


if (type != operandType)
return;
}
}
}

// All constraints are satisfied.
res.push_back(op);
return;
};

(*payloadOps.begin())->walk(matchFun);
if (incorrectNumOperandTypes)
return emitDefiniteFailure("If filter_operand_types contains more than a "
"type, then it must contain as much types as "
"the number of operands in the target ops");
results.set(cast<OpResult>(getResult()), res);
return DiagnosedSilenceableFailure::success();
}
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-match.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,46 @@ transform.sequence failures(propagate) {

// -----

func.func @by_operand_type() {
%c2 = arith.constant 2.0: f32
%v = arith.constant 8: i32
%r1 = math.fpowi %c2, %v : f32, i32
// expected-remark @below {{matched op name}}
%r2 = arith.addf %c2, %c2 : f32
// expected-remark @below {{matched op name}}
%r3 = arith.fptoui %r2 : f32 to i32
return
}

transform.sequence failures(propagate) {
^bb1(%arg1: !transform.any_op):
%match_name1 = transform.structured.match
ops{["arith.fptoui"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %match_name1, "matched op name" : !transform.any_op
transform.test_consume_operand %match_name1 : !transform.any_op

%match_name2 = transform.structured.match
ops{["arith.addf"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %match_name2, "matched op name" : !transform.any_op
transform.test_consume_operand %match_name2 : !transform.any_op

%no_match_name1 = transform.structured.match
ops{["arith.fptoui"]} filter_operand_types = [i32] in %arg1 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %no_match_name1, "should not match" : !transform.any_op
transform.test_consume_operand %no_match_name1 : !transform.any_op

%no_match_name2 = transform.structured.match
ops{["math.fpowi"]} filter_operand_types = [f32] in %arg1 : (!transform.any_op) -> !transform.any_op
transform.test_print_remark_at_operand %no_match_name2, "should not match" : !transform.any_op
transform.test_consume_operand %no_match_name2 : !transform.any_op

// expected-error @+1 {{If filter_operand_types contains more than a type, then it must contain as much types as the number of operands in the target ops}}
%failure_match = transform.structured.match
ops{["arith.fptoui"]} filter_operand_types = [i32, i32] in %arg1 : (!transform.any_op) -> !transform.any_op
}

// -----

func.func @foo(%a: tensor<4x4xf32>, %b: tensor<4x4xf32>, %c: tensor<4x4xf32>) {
%c0 = arith.constant 0.0 : f32
// expected-remark @below {{tileable}}
Expand Down