
[mlir][linalg] Fix weight dimension ordering in 2D grouped conv #73855


Merged: 1 commit into llvm:main on Dec 1, 2023

Conversation

@ubfx (Member) commented Nov 29, 2023

The `conv_2d_ngchw_fgchw` Op implements 2-D grouped convolution with dimensions ordered as given in the name. However, the current implementation orders the weights as `gfchw` instead of `fgchw`. This was already pointed out in an old Phabricator revision that never landed: https://reviews.llvm.org/D150064

This patch

  1. Adds a new op `conv_2d_ngchw_gfchw`.
  2. Fixes the dimension ordering of the old op `conv_2d_ngchw_fgchw`.
  3. Adds tests with static (non-dynamic) dimensions so that the layouts are easier to understand (a reference sketch follows this list).
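For reference, here is a minimal NumPy sketch (not part of the patch) of the computation the corrected ops perform, using the static shapes from the new tests (N=1, G=5, F=2, C=3, a 32x32 input, 3x3 kernels, unit strides and dilations); the helper name conv_2d_ngchw_gfchw_ref is hypothetical:

import numpy as np

def conv_2d_ngchw_gfchw_ref(I, K):
    # I: NGCHW input, K: GFCHW weights, returns an NGFHW output.
    n, g, c, ih, iw = I.shape
    _, f, _, kh, kw = K.shape
    oh, ow = ih - kh + 1, iw - kw + 1   # unit strides/dilations assumed
    O = np.zeros((n, g, f, oh, ow), dtype=I.dtype)
    for y in range(kh):
        for x in range(kw):
            # Contract over the per-group channel dimension c; the group
            # dimension g is shared between input, weights and output.
            O += np.einsum("ngchw,gfc->ngfhw",
                           I[:, :, :, y:y + oh, x:x + ow],
                           K[:, :, :, y, x])
    return O

I = np.random.rand(1, 5, 3, 32, 32).astype(np.float32)   # NGCHW
K = np.random.rand(5, 2, 3, 3, 3).astype(np.float32)     # GFCHW
assert conv_2d_ngchw_gfchw_ref(I, K).shape == (1, 5, 2, 30, 30)  # NGFHW

For the fgchw variant, only the weight layout changes: K would have shape (2, 5, 3, 3, 3) and its einsum subscript would be "fgc" instead of "gfc"; the NGFHW output is the same.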

@llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir-linalg

Author: Felix Schneider (ubfx)

Changes

The `conv_2d_ngchw_fgchw` Op implements 2-D grouped convolution with dimensions ordered as given in the name. However, the current implementation orders the weights as `gfchw` instead of `fgchw`. This was already pointed out in an old Phabricator revision that never landed: https://reviews.llvm.org/D150064

This patch

  1. Adds a new op `conv_2d_ngchw_gfchw`.
  2. Fixes the dimension ordering of the old op `conv_2d_ngchw_fgchw`.
  3. Adds tests with static (non-dynamic) dimensions so that the layouts are easier to understand.

Full diff: https://github.com/llvm/llvm-project/pull/73855.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+100-1)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+27-1)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+32)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 12d520cd382413a..1ff6c4086cf3576 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2911,7 +2911,106 @@ structured_op: !LinalgStructuredOpConfig
     kind: output_tensor
     type_var: U
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s11, s1, s3, s7)>
+      (s0, s1, s11, s3, s7)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s6, s10)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d5, d3 * s4 + d6 * s6, d4 * s8 + d7 * s10)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d2, d1, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_ngchw_gfchw
+  cpp_class_name: Conv2DNgchwGfchwOp
+  doc: |-
+    Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: GFCHW.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3 * s4 + s5 * s6, s7 * s8 + s9 * s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s1, s11, s2, s5, s9)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s11, s3, s7)>
   - !LinalgOperandDefConfig
     name: strides
     kind: index_attr
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 62b7da2ae2b5337..5b05364f6d35f3b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -780,7 +780,7 @@ def conv_2d_ngchw_fgchw(
         T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
     ),
     K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
-    O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
     strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
     dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
 ):
@@ -790,6 +790,32 @@ def conv_2d_ngchw_fgchw(
       * Input: NGCHW.
       * Kernel: FGCHW.
 
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+    """
+    implements(ConvolutionOpInterface)
+    domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+    O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
+        U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+    ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
+
+
+@linalg_structured_op
+def conv_2d_ngchw_gfchw(
+    I=TensorDef(
+        T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+    ),
+    K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
+    O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
+    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+    """Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NGCHW.
+      * Kernel: GFCHW.
+
     Numeric casting is performed on the operands to the inner multiply, promoting
     them to the same data type as the accumulator/output.
     """
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 5ca35155854d332..29977a71dbb8644 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -409,6 +409,38 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
 
 // -----
 
+// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
+func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+  // CHECK:      linalg.conv_2d_ngchw_fgchw
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
+    outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw
+func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+  // CHECK:      linalg.conv_2d_ngchw_gfchw
+  // CHECK-SAME:   dilations = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   strides = dense<1> : tensor<2xi64>
+  // CHECK-SAME:   ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+  // CHECK-SAME:   outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+                                         strides = dense<1> : tensor<2xi64>}
+     ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+    outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+  return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
 func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
   // CHECK:      %{{.+}} = linalg.conv_3d_ndhwc_dhwcf

@llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir

ubfx merged commit ed22bf6 into llvm:main on Dec 1, 2023
ubfx deleted the linalg-grouped-convolution-dimorder branch on December 8, 2023
ubfx added a commit to llvm/torch-mlir that referenced this pull request Dec 8, 2023
…ion ordering (#2623)

The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where

1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W

This has now been fixed in llvm/llvm-project#73855, which broke the torch-mlir lowering to that Op.

This patch switches the torch-mlir lowering to the newly introduced
`linalg.conv_2d_ngchw_gfchw` op, which accesses weights in an order that
is compatible with PyTorch's memory layout (see the sketch below).

Fix #2622
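A minimal sketch of why gfchw matches PyTorch's layout (the shapes below are illustrative, not taken from the torch-mlir patch itself): PyTorch stores grouped Conv2d weights as (out_channels, in_channels // groups, kH, kW), with the output channels of each group forming a contiguous block, so a gfchw view is a zero-copy reshape while fgchw would require a transpose.

import torch

G, F, C, kH, kW = 5, 2, 3, 3, 3
w = torch.randn(G * F, C, kH, kW)   # native PyTorch weight layout, groups=G
w_gfchw = w.view(G, F, C, kH, kW)   # zero-copy reshape: gfchw comes for free
w_fgchw = w_gfchw.transpose(0, 1)   # fgchw would need this transpose
assert not w_fgchw.is_contiguous()  # i.e. a copy to materialize contiguously

Lowering to `linalg.conv_2d_ngchw_gfchw` therefore presumably only needs a reshape of the weights, whereas targeting the fgchw op would additionally require materializing that transpose.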