-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg] Exclude non-convolutional ops from isaConvolutionOpInterface #102087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: None (yifeizh2) ChangesEnhance convolution op judgement logic Full diff: https://github.com/llvm/llvm-project/pull/102087.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 6ee1810c2ff2b..41143e0a5e347 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -762,7 +762,8 @@ enum class MatchConvolutionResult {
NotProjectedPermutations,
NonConvolutionLoop,
OutputDimsNotParallel,
- NonOutputDimNotReduction
+ NonOutputDimNotReduction,
+ NoValidConvolvedDim
};
} // namespace mlir::linalg::detail
@@ -810,6 +811,8 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// - Depth multiplier : unconvolved in input, present in output, present in
// filter.
llvm::SmallDenseSet<int64_t> allLoopDims;
+ bool hasOutputImageDim = false;
+ bool hasFilterLoopDim = false;
for (auto outputExpr : indexingMaps.back().getResults()) {
int64_t outputDim = cast<AffineDimExpr>(outputExpr).getPosition();
if (inputExprWalker.unConvolvedDims.count(outputDim) &&
@@ -825,6 +828,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
// Output image Loop dimension.
if (iteratorTypes[outputDim] != utils::IteratorType::parallel)
return MatchConvolutionResult::OutputDimsNotParallel;
+ hasOutputImageDim = true;
allLoopDims.insert(outputDim);
continue;
}
@@ -862,6 +866,7 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
return MatchConvolutionResult::NonOutputDimNotReduction;
if (allLoopDims.count(filterDim))
return MatchConvolutionResult::NonConvolutionLoop;
+ hasFilterLoopDim = true;
allLoopDims.insert(filterDim);
continue;
}
@@ -886,6 +891,9 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
if (allLoopDims.size() != linalgOp.getNumLoops())
return MatchConvolutionResult::NonConvolutionLoop;
+ if (!hasOutputImageDim || !hasFilterLoopDim)
+ return MatchConvolutionResult::NoValidConvolvedDim;
+
if (dimensions) {
FailureOr<ConvolutionDimensions> res =
inferConvolutionDimsImpl(linalgOp, inputExprWalker,
|
If we allow |
@qedawkins Sorry for bothering. I wonder why we allowed empty convolved dims in the initial design. |
This has bothered me as well, especially because this helper is inconsistent with which named ops carry the convolution interface. I think it would be best to just expose |
Hi. I updated the logic and hope to gain more suggestions. Currently the logic for |
Thanks for working on this! For those of us a bit less familiar with the logic being updated here ...
... what's broken and how does this PR fix it? :) |
Currently, the The |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add the description you just gave to the PR description? Additionally please add a short explanation in the description for how downstream users can preserve the current behavior with the given API.
It is good to be clear about the motivation behind the change in the description: https://llvm.org/docs/DeveloperPolicy.html#commit-messages
5328b74
to
b254151
Compare
I provided a simplified explanation to both PR description section and commit message. Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, can you change the PR title to something more descriptive like
[mlir][linalg] Exclude non-convolutional ops from isaConvolutionOpInterface
like @banach-space was looking for?
Thanks, that's exactly what I was after 🙏🏻 Some testing would be nice, but that would have to be a unit test and there's just none for Linalg atm :( I'm leaving that as a nice-to-have. This change makes sense to :) |
I was trying to think about what kind of test could be added, but this is pretty much just exposing an option that was already there. If we wanted to make it essentially NFC we could change the default back. We could also add a transform dialect op for |
I had planned to add some unit tests for this logic change, but I failed to find existing unit tests for its counterparts (e.g. As for testing with transform dialect op, currently we have |
I am not sure how useful such an operation would be given that |
I agree. The functionality of |
@qedawkins Since there is no more review comments, can I merge this PR by myself? Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good to me
Currently, `isaConvolutionOpInterface` returns false positive for linalg binary elementwise ops, because the function's underlying logic does not require the linalg op to have convolved dims. We avoid such false positive by further checking the non-emptyness of convolved dims.
b254151
to
4b80eb7
Compare
This patch seems to cause
Would you mind taking a look? Thanks! |
This patch fixes: mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp:906:11: error: enumeration value 'EmptyConvolvedDims' not handled in switch [-Werror,-Wswitch] with a workaround. I've notified the author of the new enum value in #102087.
I've submitted 2ef3dcf as a workaround. I'd appreciate if you could fix the message part of it. Thanks! |
Sorry for my negligence, and thanks for your providing the workaround for it. I will fix the message part soon! |
Enhance
isaConvolutionOpInterface
logic.Currently,
isaConvolutionOpInterface
returns false positive for linalg binary elementwise ops, because the function's underlying logic does not require the input linalg op to have convolved dims. We avoid such false positive by further checking the non-emptyness of convolved dims.