
[mlir][linalg] Add a test for inferConvolutionDimsImpl #90057


Merged

Conversation

banach-space
Contributor

@banach-space banach-space commented Apr 25, 2024

Adds a test for inferConvolutionDimsImpl to exercise the logic for depthwise convs.

At the moment, `inferConvolutionDimsImpl` will "remove" the "unconvolved" dims from the
calculation of the channel dims. However, that's incorrect for depthwise
convolutions, for which the channel dimension falls into that group (i.e. it is
one of the "unconvolved" dims).
@llvmbot
Member

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Andrzej Warzyński (banach-space)

Changes

ATM, inferConvolutionDimsImpl will "remove" "unconvolved" dims from the
calculation of the channel dims. However, that's incorrect for depthwise
convolutions for which the channel dimension falls into that group (i.e.
"unconvolved" dims).


Full diff: https://github.com/llvm/llvm-project/pull/90057.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (-1)
  • (modified) mlir/test/Dialect/Linalg/match-ops-interpreter.mlir (+22)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index 3627ff6617eda3..3b92da5ceccd39 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -556,7 +556,6 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
   // filterDims & outputDims - unConvolvedDims are the output channel iterators.
   llvm::SmallDenseSet<int64_t> oc = filterDims;
   llvm::set_intersect(oc, outputDims);
-  llvm::set_subtract(oc, inputExprWalker.unConvolvedDims);
 
   // filterDims & outputDims & unConvolvedDims are the depth iterators.
   llvm::SmallDenseSet<int64_t> depth = filterDims;
diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
index 24c7bdd9e1050e..c637e1df7efd3e 100644
--- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
+++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir
@@ -1062,6 +1062,28 @@ module attributes { transform.target_tag = "start_here" } {
     return %result : tensor<10x18x15xf64>
   }
 
+  func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
+    %cst = arith.constant 0.0 : f32 
+    %empty = tensor.empty() : tensor<1x10x191x48xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+    // expected-remark @below {{convolution}}
+    // expected-remark @below {{batch dims 0}}
+    // expected-remark @below {{output image dims 1 : i64, 2 : i64}}
+    // expected-remark @below {{output channel dims 3}}
+    // expected-remark @below {{filter loop dims 4 : i64, 5 : i64}}
+    // expected-remark @below {{input channel dims}}
+    // expected-remark @below {{depth dims 3}}
+    // expected-remark @below {{strides 1 : i64, 1 : i64}}
+    // expected-remark @below {{dilations 1 : i64, 1 : i64}}
+    %result = linalg.depthwise_conv_2d_nhwc_hwc {
+      dilations = dense<1> : tensor<2xi64>,
+      strides = dense<1> : tensor<2xi64>}
+      ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
+      outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
+
+    return %result : tensor<1x10x191x48xf32>
+  }
+
   func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
     %cst = arith.constant 0.0 : f32
     %empty = tensor.empty() : tensor<8x32x32x16xf32>

@llvmbot
Member

llvmbot commented Apr 25, 2024

@llvm/pr-subscribers-mlir

@banach-space
Contributor Author

@qedawkins I've not worked much with non-depthwise convs, and `inferConvolutionDimsImpl` is quite ... convolved 😅 Could you please double-check this and let me know if I missed something?

Contributor

@qedawkins qedawkins left a comment


Yeah, simplifications to the code would be welcome. There is a depth dim classification for this though; maybe it would make more sense if we renamed it to depthChannel? Also, I'm not necessarily opposed to a change in classification. I just wrote the current impl to be more in line with the inferContractionDims impl, and I haven't worked much with depthwise convs.

// expected-remark @below {{output channel dims 3}}
// expected-remark @below {{filter loop dims 4 : i64, 5 : i64}}
// expected-remark @below {{input channel dims}}
// expected-remark @below {{depth dims 3}}
Contributor

@qedawkins qedawkins Apr 25, 2024


This doesn't look like what I would expect because it is classifying dim 3 twice. The naming might need to be improved for depthwise convolutions, but the idea with these classifications is that they tell users exactly how each different dimension appears in the indexing maps (and thus input/filter/output tensors). Right now the impl uses this to imply that all dims classified as outputChannel are only featured in the filter and output maps. Dims classified as depth appear in all three of the input, filter, and output.
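The classification rule described here can be sketched as a predicate over which indexing maps a loop dim appears in. This is a hypothetical, simplified illustration (it ignores the convolved vs. unconvolved distinction on the input side), not the actual implementation:

```cpp
// Hypothetical sketch: which indexing maps a loop dim appears in determines
// its classification.
#include <cassert>
#include <cstdint>
#include <set>

struct ConvMaps {
  std::set<int64_t> input, filter, output; // dims appearing in each indexing map
};

// Appears in the filter and output maps only -> output channel.
bool isOutputChannel(const ConvMaps &m, int64_t d) {
  return m.filter.count(d) && m.output.count(d) && !m.input.count(d);
}

// Appears in all three of the input, filter, and output maps -> depth.
bool isDepth(const ConvMaps &m, int64_t d) {
  return m.filter.count(d) && m.output.count(d) && m.input.count(d);
}
```

Under this rule, the channel dim of a depthwise conv (dim 3 in the test above, which also appears in the input map) is a depth dim and must not also be classified as an output channel.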

I'm noticing that there is no description for depth in this comment though:

/// Find at least 1 parallel (output_image) and reduction (filter_loop)
/// dimension candidates that form a convolution subcomputation within
/// `linalgOp`. The LHS is assumed to be the convolution input while the
/// RHS is assumed as the filter.
/// These dimensions are such that:
/// 1. Optional batch dimensions that appear in the input and filter.
/// 2. The output_image dimension is involved in a cross-correlation along LHS
/// (i.e. it is a permutation on RES and LHS and has an associated
/// filter_loop in RHS).
/// 3. Optional output_channel dimension is involved in an outer-product along
/// RHS (i.e. it is a permutation on RES and RHS and does not appear in
/// LHS).
/// 4. Optional input_channel dimension appears as a permutation on LHS and
/// RHS.
/// 5. The filter_loop dimension appears as a permutation on the RHS and
/// represents the shape of the kernel cross-correlated along a
/// corresponding output_image dim.
/// 6. The input_channel dimension appears as a permutation on LHS and RHS.
/// 7. All dimensions appear only once in any given indexing map.
/// This allows e.g. detecting that some convolution is embedded within
/// `linalgOp` with some orthogonal heuristic.
/// When multiple dimension occurrences exist that match any classification
/// indices are returned in sorted order.
/// Returns a failure if `output_image` (and implicitly `filter_loop`) is empty.

(also the comment is kind of convoluted and could probably be rewritten, maybe with an example).

Contributor Author


Thanks for pointing this out - now I see what's happening here. IIUC, "channel" is only used for non-depthwise convolutions. For depthwise convs (HWC), this logic uses the term "depth" instead. Now, "depth" is:

  • a parallel dim
  • identical for all inputs.

It looks like this PR is breaking rather than fixing things 😂 Let me update and rephrase.

@banach-space banach-space changed the title [mlir][linalg] Fix inferConvolutionDimsImpl (depthwise convs) [mlir][linalg] Add a test for inferConvolutionDimsImpl Apr 25, 2024
@banach-space
Contributor Author

Thanks for taking a look, Quinn! I've rephrased this patch - adding a test for depthwise convs should be sufficient.

Also I'm not necessarily opposed to a change in classification, I just wrote the current impl to be more in line with the inferContractionDims impl and I haven't worked much with depthwise convs.

Your implementation is perfectly fine. I think that what we are really missing is some high-level reference doc that would classify all the convs in Linalg (and normalize the terminology). I don't have the right overview as I've been mostly working with depthwise convs. But this is on my TODO list and I would also be more than happy to review PRs for this. We'll get there eventually!

Contributor

@qedawkins qedawkins left a comment


Nice, thanks for pushing on this! Tests and documentation improvements are always welcome :)

@banach-space banach-space merged commit 3005ca2 into llvm:main Apr 25, 2024