Skip to content

Commit 4b80eb7

Browse files
committed
[mlir][linalg] Fix isaConvolutionOpInterface logic
Currently, `isaConvolutionOpInterface` returns false positive for linalg binary elementwise ops, because the function's underlying logic does not require the linalg op to have convolved dims. We avoid such false positive by further checking the non-emptyness of convolved dims.
1 parent 2b66417 commit 4b80eb7

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,12 @@ struct ConvolutionDimensions {
110110
FailureOr<ConvolutionDimensions> inferConvolutionDims(LinalgOp linalgOp);
111111

112112
/// Checks whether `linalgOp` conforms to ConvolutionOpInterface.
113+
/// By default, we require the `linalgOp` to have non-empty convolved dims
114+
/// (implicitly non-empty `output_image` and `filter_loop`).
115+
/// Users can loosen the constraint by setting `allowEmptyConvolvedDims` to true
113116
// TODO: embed within `isa<ConvolutionOpInterface>` if possible / natural.
114-
bool isaConvolutionOpInterface(LinalgOp linalgOp);
117+
bool isaConvolutionOpInterface(LinalgOp linalgOp,
118+
bool allowEmptyConvolvedDims = false);
115119

116120
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
117121
bool isaCopyOpInterface(LinalgOp linalgOp);
@@ -175,9 +179,12 @@ enum class MatchConvolutionResult;
175179
/// Checks whether `op` conforms to ConvolutionOpInterface and populates
176180
/// `dimensions` with indexes of the different kinds of dimensions when
177181
/// present.
182+
/// If `allowEmptyConvolvedDims` is not set, we further checks whether the `op`
183+
/// contains convolved dims.
178184
MatchConvolutionResult
179185
isConvolutionInterfaceImpl(Operation *op,
180-
ConvolutionDimensions *dimensions = nullptr);
186+
ConvolutionDimensions *dimensions = nullptr,
187+
bool allowEmptyConvolvedDims = false);
181188

182189
/// Returns the error message corresponding to the convolution checking return
183190
/// code.

mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -762,13 +762,15 @@ enum class MatchConvolutionResult {
762762
NotProjectedPermutations,
763763
NonConvolutionLoop,
764764
OutputDimsNotParallel,
765-
NonOutputDimNotReduction
765+
NonOutputDimNotReduction,
766+
EmptyConvolvedDims
766767
};
767768
} // namespace mlir::linalg::detail
768769

769770
mlir::linalg::detail::MatchConvolutionResult
770771
mlir::linalg::detail::isConvolutionInterfaceImpl(
771-
Operation *op, ConvolutionDimensions *dimensions) {
772+
Operation *op, ConvolutionDimensions *dimensions,
773+
bool allowEmptyConvolvedDims) {
772774
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
773775
if (!linalgOp)
774776
return MatchConvolutionResult::NotLinalgOp;
@@ -886,10 +888,12 @@ mlir::linalg::detail::isConvolutionInterfaceImpl(
886888
if (allLoopDims.size() != linalgOp.getNumLoops())
887889
return MatchConvolutionResult::NonConvolutionLoop;
888890

891+
if (!allowEmptyConvolvedDims && inputExprWalker.convolvedDims.empty())
892+
return MatchConvolutionResult::EmptyConvolvedDims;
893+
889894
if (dimensions) {
890-
FailureOr<ConvolutionDimensions> res =
891-
inferConvolutionDimsImpl(linalgOp, inputExprWalker,
892-
/*allowEmptyConvolvedDims=*/true);
895+
FailureOr<ConvolutionDimensions> res = inferConvolutionDimsImpl(
896+
linalgOp, inputExprWalker, allowEmptyConvolvedDims);
893897
assert(succeeded(res) && "unexpected failure to infer convolution dims");
894898
*dimensions = *res;
895899
}
@@ -920,8 +924,10 @@ mlir::linalg::detail::getMatchConvolutionMessage(MatchConvolutionResult res) {
920924
llvm_unreachable("unhandled MatchConvolutionResult case");
921925
}
922926

923-
bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp) {
924-
return linalg::detail::isConvolutionInterfaceImpl(linalgOp.getOperation()) ==
927+
bool mlir::linalg::isaConvolutionOpInterface(LinalgOp linalgOp,
928+
bool allowEmptyConvolvedDims) {
929+
return linalg::detail::isConvolutionInterfaceImpl(
930+
linalgOp.getOperation(), nullptr, allowEmptyConvolvedDims) ==
925931
linalg::detail::MatchConvolutionResult::Success;
926932
}
927933

0 commit comments

Comments
 (0)