-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Transform] Add attribute in MatchOp to filter by operand type #67994
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
66efa5f
to
c3b9f85
Compare
It seems like build failure is related to 39fec54. |
Ping |
This comment was marked as outdated.
This comment was marked as outdated.
(Rewritten) A negative testcase will be nice, say to check if a match fail when operand type does NOT match. |
Thanks for your feedback, I just added a second test case to cover the negative case as well. |
@@ -535,6 +535,7 @@ def MatchOp : Op<Transform_Dialect, "structured.match", | |||
- attribute: the matched op must have all specified attributes (with their | |||
specified values). | |||
- filter_result_type: the matched op must return exactly this one type. | |||
- filter_operand_type: all the operands of the matched op must must be of this type. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add line break
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved, thanks for the tip
@@ -1180,6 +1180,15 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter, | |||
return; | |||
} | |||
|
|||
if (getFilterOperandType().has_value()) { | |||
Type t = getFilterOperandType().value(); | |||
for (auto type : op->getOperandTypes()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can use !llvm::all_of
. Also, drop trivial braces.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved, although I personally feel that with llvm::all_of
it is less readable
|
||
// ----- | ||
|
||
func.func @by_operand_type_negative_match() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate test case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved, thanks
|
||
// ----- | ||
|
||
func.func @by_operand_type_negative_match() { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A single test case for "match" and "no match" should be sufficient. E.g., add two ops "test.foo"() : () -> (i32)
and "test.foo"() : () -> (f32)
. One will be matched, the one other one won't.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved, thanks for the tip
@@ -556,7 +558,8 @@ def MatchOp : Op<Transform_Dialect, "structured.match", | |||
OptionalAttr<StrArrayAttr>:$ops, | |||
OptionalAttr<MatchInterfaceEnum>:$interface, | |||
OptionalAttr<DictionaryAttr>:$op_attrs, | |||
OptionalAttr<TypeAttr>:$filter_result_type); | |||
OptionalAttr<TypeAttr>:$filter_result_type, | |||
OptionalAttr<TypeAttr>:$filter_operand_type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems too restrictive: operations often have multiple operand types, can we turn this into a list ?
Bonus points for doing the same for the result_types (although single result ops are much more common)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for suggesting this. Actually, I considered doing it, but I didn't because it was unclear to me how the list should work.
After considering it again, I have added a new commit which adds supports for this. The list expects to have a length equal to the number of operands, and the match.op
will only succeed if the list of operand types match exactly the operand types in the target op (following the same order).
Besides, I was not sure how to report the error when the length of the list is incorrect, because the error cannot be reported inside the lambda. Thus, I have created an auxiliary variable wrong_operand_filter
that is checked after the lambda, but I'm open to change it if there is a more elegant way.
PS: For the result_types
I agree, but I feel like it is more appropriate for a separate PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay.
@@ -1141,6 +1141,7 @@ transform::MatchOp::apply(transform::TransformRewriter &rewriter, | |||
} | |||
|
|||
SmallVector<Operation *> res; | |||
bool wrong_operand_filter = false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
camel case: wrongOperandFilter
. Nit: Actually, I would call the variable incorrectNumOperandTypes
or sth. like that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved, thanks
} | ||
|
||
for (auto const &it : | ||
llvm::zip(getFilterOperandTypes().value(), operandTypes)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Use llvm::zip_equal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solved, thanks for the tip
return; | ||
} | ||
|
||
for (auto const &it : |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can write for (auto [attr, operandType] : llvm::zip_equal ...
to avoid std::get
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good to know, thanks
llvm::zip(getFilterOperandTypes().value(), operandTypes)) { | ||
auto attr = dyn_cast<mlir::TypeAttr>(std::get<0>(it)); | ||
Type type = attr.getValue().cast<::mlir::Type>(); | ||
Type t = getElementTypeOrSelf(std::get<1>(it)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there should be no getElementTypeOrSelf
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has been removed in the latest commit
if (getFilterOperandTypes().has_value()) { | ||
mlir::ArrayAttr types = getFilterOperandTypes().value(); | ||
auto operandTypes = op->getOperandTypes(); | ||
if (types.size() != operandTypes.size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is different from the op documentation:
all the operands of the matched op must must be of
this type
Does not handle the case where an op has multiple operands but only one type is specified in the transform op. (If you did not intend to support such cases, update the documentation. It sounds like if there's 1 type but the op has multiple operands, then all operands must have that type.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right. Actually I want it to support that case so I have fixed the implementation, it should be fixed now
…than one type. Small fixes. Add more testcases
Thanks for your time. Apart from the fixes, I have added a few more test cases to show different scenarios:
|
for (auto [attr, operandType] : | ||
llvm::zip_equal(getFilterOperandTypes().value(), operandTypes)) { | ||
auto typeattr = dyn_cast<mlir::TypeAttr>(attr); | ||
Type type = typeattr.getValue().cast<::mlir::Type>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the cast necessary? Also please write “cast<…>(x)”. the other syntax is deprecated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think it's needed because attr
is of type Attribute
but we need it to be TypeAttr
. I have changed the cast to the one you suggested, thanks for the tip!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks!
I don't see the merge button anywhere @nicolasvasilache @matthias-springer, and github says that "Only those with write access to this repository can merge pull requests", so it seems like I don't have access. Can you grant me write access so I can merge it myself? I'd like to contribute more patches in the future. Thanks! |
Please follow https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access to request commit access. |
…nd type (#67994)" This reverts commit c439913. Test fails https://lab.llvm.org/buildbot/#/builders/272/builds/2757
Thanks for the revert. For what it's worth. The test was also failing for me locally. Maybe the cause is that f30a402 tested against an outdated |
Thanks for the revert. @rikhuijzer I think it's because the MLIR tests have changed, and now they require to have a named transform sequence. My bad, I'll fix that shortly. |
…nd type (#67994)" Test was failing due to a different transform sequence declaration (transform sequence were used, while now it should be named transform sequence). Test is now fixed.
This patchs adds the
filter_operand_type
attribute to transform::MatchOp. With this attribute, the MatchOp will only match an op when the type of all the operands are equal to the specified type.