[mlir][linalg] Fix weight dimension ordering in 2D grouped conv #73855
The `conv_2d_ngchw_fgchw` Op implements 2D grouped convolution with dimensions ordered as given in the name. However, the current implementation orders weights as `gfchw` instead of `fgchw`. This was already pointed out in an old Phabricator revision which never landed: https://reviews.llvm.org/D150064

This patch

1. Adds a new op `conv_2d_ngchw_gfchw`
2. Fixes the dimension ordering of the old op `conv_2d_ngchw_fgchw`
3. Adds tests with non-dynamic dimensions so that the layouts are easier to understand.
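To make the two kernel layouts concrete, here is a minimal pure-Python sketch of a grouped convolution that accepts either ordering. The helper names (`zeros`, `grouped_conv_2d`) and the tiny shapes are hypothetical and chosen only for illustration; this is not the Linalg implementation, and the no-padding output-size formula is an assumption.

```python
from itertools import product


def zeros(shape):
    """Build a nested list of zeros with the given shape."""
    if len(shape) == 1:
        return [0.0] * shape[0]
    return [zeros(shape[1:]) for _ in range(shape[0])]


def grouped_conv_2d(inp, kernel, kernel_layout, strides=(1, 1), dilations=(1, 1)):
    """Reference 2-D grouped convolution on nested lists (hypothetical helper).

    inp is N,G,C,H,W; kernel is F,G,C,KH,KW ("fgchw") or G,F,C,KH,KW ("gfchw").
    The result is always ordered N,G,F,OH,OW.
    """
    if kernel_layout == "fgchw":
        num_f = len(kernel)
        weight = lambda f, g, c, kh, kw: kernel[f][g][c][kh][kw]
    elif kernel_layout == "gfchw":
        num_f = len(kernel[0])
        weight = lambda f, g, c, kh, kw: kernel[g][f][c][kh][kw]
    else:
        raise ValueError(f"unknown kernel layout: {kernel_layout}")

    num_n, num_g, num_c = len(inp), len(inp[0]), len(inp[0][0])
    in_h, in_w = len(inp[0][0][0]), len(inp[0][0][0][0])
    k_h, k_w = len(kernel[0][0][0]), len(kernel[0][0][0][0])
    s_h, s_w = strides
    d_h, d_w = dilations
    # Assumption: the usual output size for a convolution without padding.
    out_h = (in_h - d_h * (k_h - 1) - 1) // s_h + 1
    out_w = (in_w - d_w * (k_w - 1) - 1) // s_w + 1

    out = zeros((num_n, num_g, num_f, out_h, out_w))
    parallel = product(range(num_n), range(num_g), range(num_f),
                       range(out_h), range(out_w))
    for n, g, f, oh, ow in parallel:
        # Reduction over channels and the kernel window.
        for c, kh, kw in product(range(num_c), range(k_h), range(k_w)):
            out[n][g][f][oh][ow] += (
                inp[n][g][c][oh * s_h + kh * d_h][ow * s_w + kw * d_w]
                * weight(f, g, c, kh, kw)
            )
    return out


# Tiny example: N=1, G=2, C=1, H=W=3, F=2, KH=KW=2.
G, F = 2, 2
inp = [[[[[float(g + 3 * h + w) for w in range(3)] for h in range(3)]
         for _ in range(1)] for g in range(G)]]
k_fg = [[[[[float(10 * f + g + kh + kw) for kw in range(2)] for kh in range(2)]
          for _ in range(1)] for g in range(G)] for f in range(F)]
# The GFCHW kernel holds the same weights with the first two dimensions swapped.
k_gf = [[k_fg[f][g] for f in range(F)] for g in range(G)]

out_fg = grouped_conv_2d(inp, k_fg, "fgchw")
out_gf = grouped_conv_2d(inp, k_gf, "gfchw")
```

Both calls compute the same output; the two ops differ only in how the kernel operand is indexed, which is why a kernel stored in one layout cannot be passed unchanged to the other op.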
@llvm/pr-subscribers-mlir-linalg

Author: Felix Schneider (ubfx)

Full diff: https://github.com/llvm/llvm-project/pull/73855.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 12d520cd382413a..1ff6c4086cf3576 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2911,7 +2911,106 @@ structured_op: !LinalgStructuredOpConfig
kind: output_tensor
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
- (s0, s11, s1, s3, s7)>
+ (s0, s1, s11, s3, s7)>
+ - !LinalgOperandDefConfig
+ name: strides
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s4, s8)>
+ default_indices:
+ - 1
+ - 1
+ - !LinalgOperandDefConfig
+ name: dilations
+ kind: index_attr
+ index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+ -> (s6, s10)>
+ default_indices:
+ - 1
+ - 1
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d0, d1, d5, d3 * s4 + d6 * s6, d4 * s8 + d7 * s10)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d2, d1, d5, d6, d7)>
+ - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+ s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+ iterator_types:
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - parallel
+ - reduction
+ - reduction
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: O
+ value: !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: O
+ - !ScalarExpression
+ scalar_fn:
+ kind: binary
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: I
+ - !ScalarExpression
+ scalar_fn:
+ kind: type
+ fn_name: cast_signed
+ type_var: U
+ operands:
+ - !ScalarExpression
+ scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: conv_2d_ngchw_gfchw
+ cpp_class_name: Conv2DNgchwGfchwOp
+ doc: |-
+ Performs 2-D grouped convolution.
+
+ Layout:
+ * Input: NGCHW.
+ * Kernel: GFCHW.
+
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ implements:
+ - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !LinalgOperandDefConfig
+ name: I
+ kind: input_tensor
+ type_var: T1
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s2, s3 * s4 + s5 * s6, s7 * s8 + s9 * s10)>
+ - !LinalgOperandDefConfig
+ name: K
+ kind: input_tensor
+ type_var: T2
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s1, s11, s2, s5, s9)>
+ - !LinalgOperandDefConfig
+ name: O
+ kind: output_tensor
+ type_var: U
+ shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+ (s0, s1, s11, s3, s7)>
- !LinalgOperandDefConfig
name: strides
kind: index_attr
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 62b7da2ae2b5337..5b05364f6d35f3b 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -780,7 +780,7 @@ def conv_2d_ngchw_fgchw(
T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
),
K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
- O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
+ O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
@@ -790,6 +790,32 @@ def conv_2d_ngchw_fgchw(
* Input: NGCHW.
* Kernel: FGCHW.
+ Numeric casting is performed on the operands to the inner multiply, promoting
+ them to the same data type as the accumulator/output.
+ """
+ implements(ConvolutionOpInterface)
+ domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
+ O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
+ U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
+ ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
+
+
+@linalg_structured_op
+def conv_2d_ngchw_gfchw(
+ I=TensorDef(
+ T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
+ ),
+ K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
+ O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
+ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
+ dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
+):
+ """Performs 2-D grouped convolution.
+
+ Layout:
+ * Input: NGCHW.
+ * Kernel: GFCHW.
+
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output.
"""
diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir
index 5ca35155854d332..29977a71dbb8644 100644
--- a/mlir/test/Dialect/Linalg/named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/named-ops.mlir
@@ -409,6 +409,38 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
// -----
+// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
+func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+ // CHECK: linalg.conv_2d_ngchw_fgchw
+ // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+ // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+ %0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
+ outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+ return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @conv_2d_ngchw_gfchw
+func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
+ // CHECK: linalg.conv_2d_ngchw_gfchw
+ // CHECK-SAME: dilations = dense<1> : tensor<2xi64>
+ // CHECK-SAME: strides = dense<1> : tensor<2xi64>
+ // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+ // CHECK-SAME: outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+ %0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
+ strides = dense<1> : tensor<2xi64>}
+ ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
+ outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
+ return %0 : tensor<1x5x2x30x30xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
// CHECK: %{{.+}} = linalg.conv_3d_ndhwc_dhwcf
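The static shapes in the two new tests can be cross-checked with a short sketch. The function names below are hypothetical, and the output-size formula is the usual relation for a convolution without padding, which is an assumption rather than something stated in the patch.

```python
def conv_out_size(in_size, kernel_size, stride=1, dilation=1):
    """Output extent of a convolution without padding (assumed formula)."""
    return (in_size - dilation * (kernel_size - 1) - 1) // stride + 1


def grouped_conv_shapes(n, g, c, h, w, fg, kh, kw,
                        strides=(1, 1), dilations=(1, 1)):
    """Operand shapes for both kernel layouts (hypothetical helper)."""
    oh = conv_out_size(h, kh, strides[0], dilations[0])
    ow = conv_out_size(w, kw, strides[1], dilations[1])
    return {
        "input_ngchw": (n, g, c, h, w),
        "kernel_fgchw": (fg, g, c, kh, kw),
        "kernel_gfchw": (g, fg, c, kh, kw),
        # After this patch both ops produce the output as N,G,F,OH,OW.
        "output_ngfhw": (n, g, fg, oh, ow),
    }


# Sizes taken from the tests above: 5 groups, 2 filters per group,
# 3 channels per group, 32x32 input, 3x3 kernel.
shapes = grouped_conv_shapes(n=1, g=5, c=3, h=32, w=32, fg=2, kh=3, kw=3)
```

With these sizes the helper reproduces the tensor types in the tests: `2x5x3x3x3` for the FGCHW filter, `5x2x3x3x3` for the GFCHW filter, and `1x5x2x30x30` for the output of either op.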
…ion ordering (#2623)

The linalg Op `linalg.conv_2d_ngchw_fgchw` had a bug where

1. Weights were accessed as G,F,C,H,W instead of as F,G,C,H,W
2. Output was accessed as N,F,G,H,W instead of as N,G,F,H,W

This has now been fixed in llvm/llvm-project#73855, which broke the torch-mlir lowering to that Op. This patch switches the lowering in torch-mlir to the newly introduced `linalg.conv_2d_ngchw_gfchw` op, which accesses weights in an order that is compatible with PyTorch's memory layout.

Fix #2622
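The layout compatibility can be illustrated with a small row-major offset check. This is a sketch under the assumption that PyTorch stores grouped convolution weights with shape `(G * F, C, KH, KW)`, where output channel `g * F + f` belongs to group `g`; the helper `row_major_offset` and the sizes are hypothetical.

```python
from itertools import product


def row_major_offset(index, shape):
    """Flat offset of a multi-dimensional index in a row-major buffer."""
    offset = 0
    for i, dim in zip(index, shape):
        offset = offset * dim + i
    return offset


G, F, C, KH, KW = 5, 2, 3, 3, 3
torch_shape = (G * F, C, KH, KW)   # assumed PyTorch weight shape
gfchw_shape = (G, F, C, KH, KW)
fgchw_shape = (F, G, C, KH, KW)

indices = list(product(range(G), range(F), range(C), range(KH), range(KW)))

# GFCHW: every element sits at the same flat offset as in the PyTorch buffer,
# so expanding the leading dimension into (G, F) needs no data movement.
same_as_gfchw = all(
    row_major_offset((g * F + f, c, kh, kw), torch_shape)
    == row_major_offset((g, f, c, kh, kw), gfchw_shape)
    for g, f, c, kh, kw in indices
)

# FGCHW: the offsets differ, so the weights would have to be transposed first.
same_as_fgchw = all(
    row_major_offset((g * F + f, c, kh, kw), torch_shape)
    == row_major_offset((f, g, c, kh, kw), fgchw_shape)
    for g, f, c, kh, kw in indices
)
```

Under that assumption `same_as_gfchw` holds and `same_as_fgchw` does not, which matches the choice of the GFCHW op for the lowering.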