
[MLIR] Refactor to create vectorization convOp precondition check #130181


Merged
merged 9 commits into main from users/zyin/fix-vectorization-masking-convolution on Mar 17, 2025

Conversation

jerryyin
Member

@jerryyin jerryyin commented Mar 6, 2025

In corner cases, the vectorization pass may be asked to lower a conv2d op and then assert in a completely unrelated location in the vectorizeConvolution() subroutine.

This PR rejects the conv2d op early and makes the asserting routine return failure instead, as a defensive workaround.

In addressing this, the PR moves all condition checks out of the Conv1dGenerator and into the convOpPreconditionCheck() function. This rejects unsupported ops such as conv2d early and leaves a cleaner Conv1dGenerator constructor.

@llvmbot
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Zhuoran Yin (jerryyin)

Changes

In corner cases, the vectorization pass may be asked to lower a conv2d op and then assert in a completely unrelated location in the vectorizeConvolution() subroutine.

This PR rejects the conv2d op early and makes the asserting routine return failure instead, as a defensive workaround.


Full diff: https://github.com/llvm/llvm-project/pull/130181.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization-unsupported.mlir (+19)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ae04c2b6b2a5b..319dd4b2043c3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1990,8 +1990,18 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. But we will still need stride/dilation attributes that will be
   // annoying to reverse-engineer...
-  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation())) {
+    // Check if it is 2d+ convolution. If it is, return failure because we don't
+    // support it. To use this pass on a 2d+ convolution, it should have already
+    // been decomposed to 1d convolution via
+    // DecomposeConvolutionToLowerDimOpsPass.
+    if (linalgOp.getNumParallelLoops() >= 4) {
+      LDBG("precondition failed: Regular 2d+ convolutions not supported.\n");
+      return failure();
+    }
     return success();
+  }
+
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
@@ -3929,9 +3939,11 @@ static FailureOr<Operation *> vectorizeConvolution(
   if (!inputVecSizes.empty()) {
     // Only use the input vector size corresponding to the channel dim. Other
     // vector dims will be inferred from the Ops.
-    assert((isa<linalg::DepthwiseConv1DNwcWcOp>(*op) ||
-            isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) &&
-           "Not a 1D depthwise conv!");
+    if (!isa<linalg::DepthwiseConv1DNwcWcOp>(*op) &&
+        !isa<linalg::DepthwiseConv1DNcwCwOp>(*op)) {
+      return rewriter.notifyMatchFailure(
+          op, "Unexpected convolution: expected 1D depthwise conv");
+    }
     size_t chDimIdx =
         TypeSwitch<Operation *, size_t>(op)
             .Case<linalg::DepthwiseConv1DNwcWcOp>([](auto conv) { return 2; })
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 8f3b199145ce0..88d9e98c02bca 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -112,6 +112,25 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+func.func @conv2d(%3: tensor<1x64x58x58xf32>, %4:  tensor<64x64x3x3xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %5 = tensor.empty() : tensor<1x64x56x56xf32>
+  %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %7 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%3, %4 : tensor<1x64x58x58xf32>, tensor<64x64x3x3xf32>) outs(%6 : tensor<1x64x56x56xf32>) -> tensor<1x64x56x56xf32>
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nchw_fchw"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
 func.func @test_pack_no_vectorize_dynamic_shape(%arg0: tensor<?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
   %pad = arith.constant 0.000000e+00 : f32
   // expected-error @+1 {{Attempted to vectorize, but failed}}

@llvmbot
Member

llvmbot commented Mar 6, 2025

@llvm/pr-subscribers-mlir

Author: Zhuoran Yin (jerryyin)


@jerryyin jerryyin changed the title [MLIR] Allowing unsupported conv2d op to fail gracefully vectorization pass [MLIR] Allowing unsupported conv2d op to fail gracefully from vectorization Mar 6, 2025
Contributor

@banach-space banach-space left a comment


Thanks @jerryyin, convs are in real need of TLC 🙏🏻

Now, IMO, what's missing is a very high-level check in vectorizeOpPrecondition. From what I can tell, it would be totally fine to check the dims of conv ops very early on.

With your current approach, I just feel that the current fairly "custom" logic is replaced with something else that is also fairly custom. Instead, I am proposing to re-use the high-level logic that's already available.

Thanks for helping us improve this!

// For example, if it is 2d+ convolution, return failure because we don't
// support it. To use this pass on a 2d+ convolution, it should have already
// been decomposed to 1d convolution via
// DecomposeConvolutionToLowerDimOpsPass.
Contributor


I couldn't find such Pass in-tree.

Member Author

@jerryyin jerryyin Mar 7, 2025


Apologies, this is an IREE implementation detail and the referenced pass is also from IREE. I shouldn't reference such a pass upstream. Will remove.

I am proposing to re-use the high-level logic that's already available.

Let me use this thread to discuss it. This diff code block is the high-level precondition check around the linalg op. Please let me know if you are referring to a different location.

From what I can tell, it would be totally fine to check the dims of conv ops very early on.

Could you elaborate? Are you referring to explicitly invoking inferConvolutionDims()? Then, to make sure this is a regular 2d convolution, I'd check for a combination of:

  • outputImage.size() == 2
  • batch.size() == 1
  • outputChannel.size() == 1

Reject if all of those are satisfied (a rough sketch of such a check follows below). I have no problem implementing this, but just want to make sure we are on the same page.
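
For concreteness, a minimal sketch of what that rejection could look like, assuming linalg::inferConvolutionDims() is the helper in question; this is illustrative only, not the code this PR landed:

static LogicalResult rejectRegular2dConv(linalg::LinalgOp linalgOp) {
  // Infer which loops play the batch / output-image / output-channel roles.
  FailureOr<linalg::ConvolutionDimensions> dims =
      linalg::inferConvolutionDims(linalgOp);
  if (failed(dims))
    return failure();
  // A regular 2d convolution has two output-image dims plus one batch and one
  // output-channel dim; reject it so that only 1d convs reach the generator.
  if (dims->outputImage.size() == 2 && dims->batch.size() == 1 &&
      dims->outputChannel.size() == 1)
    return failure();
  return success();
}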


Taking a step back, I don't have a lot of context about the history of the vectorization code around convolution. Since this PR is not intended to be a massive re-write, I'm attempting to stay coherent with the existing code as much as possible.

One thing I've noticed, and that @hanhanW rightly pointed out, is that we can fail to build a Conv1DGenerator and still allow a caller (like vectorizeConvolution(), which constructs and uses the Conv1DGenerator) to invoke its member vectorization functions, which I find quite confusing. (If I were to implement this from scratch, I'd probably use a singleton + initialize pattern rather than the current approach (constructor + valid member variable). This way, a developer is required to invoke the initialize method and check the validity of the object before invoking anything on it.)

With this context, I find the most defensive approach is the one used in this PR right now:

  • With future implementations to be added and more flavors of convolution supported, it is very likely that the precondition check on vectorizing convolutions will grow out of sync (and this PR is a perfect example)
  • Now, instead of maintaining a separate function that does a subset of the constructor logic, why not re-use it and ensure we do the validity check? This looks reasonable, as the constructor is (if not cheaper, at least) no more expensive than having to infer the convolution dimensions.

With the above reasoning added up, this just looks to me like a better solution than inferring the convolution dimensions and rejecting a few corner cases (which can easily grow out of sync later).

Contributor


Let me use this thread to discuss it. This diff code block is the high-level precondition check around the linalg op. Please let me know if you are referring to a different location.

Similarly to Diego, I am suggesting moving the high-level logic to vectorizeOpPrecondition. Also, like the other "pre-condition" hooks, it should not require a rewriter.

Could you elaborate? Are you referring to explicitly invoking inferConvolutionDims()? Then, to make sure this is a regular 2d convolution, I'd check for a combination of:

  • outputImage.size() == 2
  • batch.size() == 1
  • outputChannel.size() == 1

Reject if all of those are satisfied. I have no problem implementing this, but just want to make sure we are on the same page.

In my naivety, I was hoping that checking, e.g., the rank of the filter or the input would be sufficient. But clearly not - the input for a non-channeled conv would be 1D, but for a channeled one would be 2D, and so on. IMO, you can just create something like this:

if (!isa<conv_type1, conv_type2, ...>(conv))
   return failure();

This will be a bit verbose, but there are just too many convs and whatever we try will be ... verbose 🤷🏻
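
For concreteness, one possible spelling of that isa<> check; the op names below are existing linalg 1d (depthwise) convolution ops, but the exact selection is illustrative rather than the list this PR ended up with:

if (!isa<linalg::Conv1DOp, linalg::Conv1DNwcWcfOp, linalg::Conv1DNcwFcwOp,
         linalg::DepthwiseConv1DNwcWcOp, linalg::DepthwiseConv1DNcwCwOp>(
        linalgOp.getOperation()))
  return failure();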

Taking a step back, I don't have a lot of context about the history of the vectorization code around convolution. Since this PR is not intended to be a massive re-write, I'm attempting to stay coherent with the existing code as much as possible.

+1 to being coherent, thanks!

I was actually going to ask - do you have any plans regarding this code beyond this PR?

One thing I've noticed, and that @hanhanW rightly pointed out, is that we can fail to build a Conv1DGenerator and still allow a caller (like vectorizeConvolution(), which constructs and uses the Conv1DGenerator) to invoke its member vectorization functions, which I find quite confusing. (If I were to implement this from scratch, I'd probably use a singleton + initialize pattern rather than the current approach (constructor + valid member variable). This way, a developer is required to invoke the initialize method and check the validity of the object before invoking anything on it.)

You should be able to simply add:

assert(isValid() && "Conv1DGenerator failed")

From what I can tell, that wouldn't break any tests and will make "validity" a strong prerequisite.
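
As a sketch of where such an assertion could sit, e.g. at the top of each generation entry point; the surrounding method name and signature are illustrative, only the assert itself comes from the suggestion above:

FailureOr<Operation *> generateNonChanneledConv() {
  // Make validity a hard precondition of every generation entry point.
  assert(isValid() && "Conv1DGenerator failed");
  // ... existing lowering logic ...
}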

With this context, I find the most defensive approach is the one used in this PR right now:

  • With future implementations to be added and more flavors of convolution supported, it is very likely that the precondition check on vectorizing convolutions will grow out of sync (and this PR is a perfect example)

There have been no new implementations in > 2 yrs. From what I can tell, we can safely assume that this will remain the case for the foreseeable future. So, I wouldn't worry about this.

  • Now, instead of maintaining a separate function that does a subset of the constructor logic, why not re-use it and ensure we do the validity check? This looks reasonable, as the constructor is (if not cheaper, at least) no more expensive than having to infer the convolution dimensions.

That sounds good in theory, but in practice it means that we need an IR rewriter for the validation. "Validation"/"pre-conditioning" should not require a rewriter.

With the above reasoning added up, this just looks to me like a better solution than inferring the convolution dimensions and rejecting a few corner cases (which can easily grow out of sync later).

How about my suggestion with isa?

Member Author

@jerryyin jerryyin Mar 10, 2025


I really appreciate your thorough review comments, which give me a ton of useful information.

I was actually going to ask - do you have any plans regarding this code beyond this PR?

Thanks for asking! I don't have any further plans... This was only meant to unblock myself from an unrelated crash that was failing downstream tests.

That sounds good in theory, but in practice it means that we need an IR writer for the validation. "Validation"/"pre-conditioning" should not require a rewriter.

Agreed, I don't like having a redundant dummy rewriter just for the validation either. In fact, I took a second look at all the places where a Conv1DGenerator member function is invoked and found that all of them have access to a rewriter. The need for the rewriter really only comes from the base class StructuredGenerator constructor. I was also surprised to find that the base class StructuredGenerator doesn't use the rewriter, yet it unnecessarily stores it as state on the class. A slightly more aggressive approach would be to remove the field from the base class and pass the rewriter through on a case-by-case basis. Then we'd have a clean way to construct the generator without requiring a rewriter. Sounds like a rabbit hole that I'd rather avoid in this PR :-p

How about my suggestion with isa?

I'll adopt this. It is a cheap enough check that seems reasonable for a precondition check, although I'll refrain from being "complete" in this check because, in reality, the linalg.conv_2d_* and linalg.conv_3d_* ops form a really long list combined, with quantized, grouped, and non-channeled variants. I'm going to leave those other variants out and check only for the simple conv2d and conv3d cases.

Contributor


Similarly to Diego, I am suggesting to move the high level logic to vectorizeOpPrecondition. Also, like other "pre-condition" hooks, it should not require a rewriter.

I thought it was moved to vectorizeOpPrecondition in this PR? The check is in vectorizeLinalgOpPrecondition, and the former calls this function. Do you suggest creating a separate function like vectorizeConvPrecondition and using it in vectorizeOpPrecondition? That is okay with me because convolution really follows a different path.

RE the verification issue: I totally agree that the verification should not depend on an IR rewriter. From what I can tell, we do not need it at all. The class needs it for StructuredGenerator, but we don't need it in the verification at all.

// Determine whether `linalgOp` can be generated with this generator
if (linalgOp.getNumDpsInputs() != 2 || linalgOp.getNumDpsInits() != 1)
  return;
lhsShaped = linalgOp.getDpsInputOperand(0)->get();
rhsShaped = linalgOp.getDpsInputOperand(1)->get();
resShaped = linalgOp.getDpsInitOperand(0)->get();
lhsShapedType = dyn_cast<ShapedType>(lhsShaped.getType());
rhsShapedType = dyn_cast<ShapedType>(rhsShaped.getType());
resShapedType = dyn_cast<ShapedType>(resShaped.getType());
if (!lhsShapedType || !rhsShapedType || !resShapedType)
  return;
// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
// (non-channeled convolution -> LHS and RHS both have single dimensions).
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
    (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
  return;

Operation *reduceOp = matchLinalgReduction(linalgOp.getDpsInitOperand(0));
if (!reduceOp)
  return;
redOp = reduceOp->getName().getIdentifier();

if (!setOperKind(reduceOp))
  return;
auto maybeKind = getCombinerOpKind(reduceOp);
// Typically convolution will have a `Add` CombiningKind but for i1 type it
// can get strength reduced to `OR` which is also supported. This strength
// reduction logic is in `buildBinaryFn` helper in the Linalg dialect.
if (!maybeKind || ((*maybeKind != vector::CombiningKind::ADD &&
                    *maybeKind != vector::CombiningKind::OR) &&
                   (oper != Pool || !isSupportedPoolKind(*maybeKind)))) {
  return;
}
reductionKind = maybeKind.value();

auto rhsRank = rhsShapedType.getRank();
switch (oper) {
case Conv:
  if (rhsRank != 1 && rhsRank != 2 && rhsRank != 3)
    return;
  break;
case Pool:
  if (rhsRank != 1)
    return;
  break;
}

// The op is now known to be valid.
valid = true;

The valid variable is only used in assertions in a few methods, e.g., depthwiseConv and conv. I think it was mainly created as a sanity check, while the newer code did not take it into account. Thus, we crashed in the other place.

The code is quite old and the precondition was added later than the conv code. I think, to give it a better structure, we can refactor the generator, because everything starts from the generator. How about we have a static class method which returns true when the given operation is supported? That is, we move the above logic check to a static method (e.g., vectorizePrecondition) without initializing any variables.

In the constructor, I'd suggest doing as little as possible and moving the assertions out of the constructor; in this context, they move to an initializer method. I'd prefer avoiding a crash in the constructor, and this way we can expose the failure handling to external users. (I don't know what the style is in LLVM, but it is quite common in environments where exceptions are disallowed. See https://abseil.io/tips/42 for more details.)

Thus, it can be something like

class Conv1DGenerator : public StructuredGenerator<LinalgOp, utils::IteratorType> {
  // The constructor only takes the rewriter and the linalg op.
  Conv1DGenerator(RewriterBase &rewriter, LinalgOp linalgOp)
      : StructuredGenerator<LinalgOp, utils::IteratorType>(rewriter, linalgOp) {}

  // Vectorization precondition.
  bool/LogicalResult vectorizePrecondition(LinalgOp linalgOp) { ... }

  // The initialization method.
  LogicalResult init() {
    // ...or do an assertion here.
    if (failed(vectorizePrecondition(...))) {
      return failure();
    }
    // Initialize the values of the class members.
  }
};

Does it look better structured?
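
A hypothetical call site for that pattern, showing how the failure would surface to the caller (the conv1dGen/init names follow the sketch above, not necessarily the landed code):

Conv1DGenerator conv1dGen(rewriter, linalgOp);
if (failed(conv1dGen.init()))
  return rewriter.notifyMatchFailure(linalgOp,
                                     "unsupported conv for the 1d vectorizer");
// Only now is it safe to call the member vectorization routines.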

@jerryyin jerryyin force-pushed the users/zyin/fix-vectorization-masking-convolution branch from aeaf451 to 74a8986 Compare March 10, 2025 18:39
Contributor

@banach-space banach-space left a comment


I am happy with the current approach, thank you @jerryyin! I've left one comment, but that's an optional nice-to-have. LGTM!

Please wait for @dcaballe and @hanhanW to approve before merging.

Contributor

@hanhanW hanhanW left a comment


It looks better to me, I left some style nits. Thanks!

@jerryyin jerryyin force-pushed the users/zyin/fix-vectorization-masking-convolution branch from cdc5625 to 13f5183 Compare March 11, 2025 14:43
@jerryyin jerryyin changed the title [MLIR] Allowing unsupported conv2d op to fail gracefully from vectorization [MLIR] Refactor to create vectorization convOp precondition check Mar 11, 2025
Contributor

@hanhanW hanhanW left a comment


Thanks for addressing my comments! Looks good to me. Just one nit about the function comment. Please also wait for a review/response from @banach-space because the approval was for the previous revision.

Contributor

@dcaballe dcaballe left a comment


Thanks! Much better. I added a few more comments.

Contributor

@banach-space banach-space left a comment


Thanks for all the refactoring - it is similar to one of the options that I had in mind.

For me, right now it's tricky to tell that it's indeed only 1D Convs that are supported. Is there an easy way to extract that info from vectorizeConvOpPrecondition that I am failing to see? 😅

@jerryyin jerryyin force-pushed the users/zyin/fix-vectorization-masking-convolution branch from 51a97cc to a58b8da Compare March 13, 2025 13:45
@llvm llvm deleted a comment from github-actions bot Mar 13, 2025
@jerryyin
Member Author

@banach-space Sorry, I somehow missed the main block of your comments initially.

tricky to tell that it's indeed only 1D Convs that are supported

Please refer to the first rank check from vectorizeConvOpPrecondition() below:

// (LHS has dimension NCW/NWC and RES has dimension NFW/NCW/NWF/NWC) OR
// (non-channeled convolution -> LHS and RHS both have single dimensions).
if ((lhsShapedType.getRank() != 3 || resShapedType.getRank() != 3) &&
    (lhsShapedType.getRank() != 1 || resShapedType.getRank() != 1))
  return failure();

For a 2d conv op like linalg.conv_2d_nchw_fchw, this will return failure because both the lhs and res shapes have rank 4, not 1 or 3. Do you think this is a good enough condition?


I appreciate everyone's review and feedback. Thank you all very much for pushing towards a better solution!

Please feel free to leave additional feedback if there is any. If not, I plan to land this around next Monday.

@banach-space
Contributor

Thanks!

For a 2d conv op like linalg.conv_2d_nchw_fchw, this will return failure because both the lhs and res shapes have rank 4, not 1 or 3. Do you think this is a good enough condition?

The condition looks good, but an additional comment would be helpful. It's not immediately obvious that that logic will reject 2D and 3D variants.

Please feel free to leave additional feedback if there is any. If not, I plan to land this around next Monday.

SG, I won't be adding any more comments beyond the one above.

Thanks for taking care of this 🙏🏻

@jerryyin jerryyin merged commit 1e89a76 into main Mar 17, 2025
11 checks passed
@jerryyin jerryyin deleted the users/zyin/fix-vectorization-masking-convolution branch March 17, 2025 13:32