Skip to content

Commit ed22bf6

Browse files
authored
[mlir][linalg] Fix weight dimension ordering in 2D grouped conv (#73855)
The `conv_2d_ngchw_fgchw` Op implements 2-D grouped convolution with dimensions ordered as given in the name. However, the current implementation orders weights as `gfchw` instead of `fgchw`. This was already pointed out in an old Phabricator revision that never landed: https://reviews.llvm.org/D150064. This patch 1) adds a new op `conv_2d_ngchw_gfchw`, 2) fixes the dimension ordering of the old op `conv_2d_ngchw_fgchw`, and 3) adds tests with non-dynamic dimensions so that they are easier to understand.
1 parent 027935d commit ed22bf6

File tree

3 files changed

+159
-2
lines changed

3 files changed

+159
-2
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2911,7 +2911,106 @@ structured_op: !LinalgStructuredOpConfig
29112911
kind: output_tensor
29122912
type_var: U
29132913
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
2914-
(s0, s11, s1, s3, s7)>
2914+
(s0, s1, s11, s3, s7)>
2915+
- !LinalgOperandDefConfig
2916+
name: strides
2917+
kind: index_attr
2918+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
2919+
-> (s4, s8)>
2920+
default_indices:
2921+
- 1
2922+
- 1
2923+
- !LinalgOperandDefConfig
2924+
name: dilations
2925+
kind: index_attr
2926+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
2927+
-> (s6, s10)>
2928+
default_indices:
2929+
- 1
2930+
- 1
2931+
indexing_maps: !LinalgIndexingMapsConfig
2932+
static_indexing_maps:
2933+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
2934+
s8, s9, s10, s11] -> (d0, d1, d5, d3 * s4 + d6 * s6, d4 * s8 + d7 * s10)>
2935+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
2936+
s8, s9, s10, s11] -> (d2, d1, d5, d6, d7)>
2937+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
2938+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
2939+
iterator_types:
2940+
- parallel
2941+
- parallel
2942+
- parallel
2943+
- parallel
2944+
- parallel
2945+
- reduction
2946+
- reduction
2947+
- reduction
2948+
assignments:
2949+
- !ScalarAssign
2950+
arg: O
2951+
value: !ScalarExpression
2952+
scalar_fn:
2953+
kind: binary
2954+
fn_name: add
2955+
operands:
2956+
- !ScalarExpression
2957+
scalar_arg: O
2958+
- !ScalarExpression
2959+
scalar_fn:
2960+
kind: binary
2961+
fn_name: mul
2962+
operands:
2963+
- !ScalarExpression
2964+
scalar_fn:
2965+
kind: type
2966+
fn_name: cast_signed
2967+
type_var: U
2968+
operands:
2969+
- !ScalarExpression
2970+
scalar_arg: I
2971+
- !ScalarExpression
2972+
scalar_fn:
2973+
kind: type
2974+
fn_name: cast_signed
2975+
type_var: U
2976+
operands:
2977+
- !ScalarExpression
2978+
scalar_arg: K
2979+
--- !LinalgOpConfig
2980+
metadata: !LinalgOpMetadata
2981+
name: conv_2d_ngchw_gfchw
2982+
cpp_class_name: Conv2DNgchwGfchwOp
2983+
doc: |-
2984+
Performs 2-D grouped convolution.
2985+
2986+
Layout:
2987+
* Input: NGCHW.
2988+
* Kernel: GFCHW.
2989+
2990+
Numeric casting is performed on the operands to the inner multiply, promoting
2991+
them to the same data type as the accumulator/output.
2992+
implements:
2993+
- LinalgConvolutionOpInterface
2994+
structured_op: !LinalgStructuredOpConfig
2995+
args:
2996+
- !LinalgOperandDefConfig
2997+
name: I
2998+
kind: input_tensor
2999+
type_var: T1
3000+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3001+
(s0, s1, s2, s3 * s4 + s5 * s6, s7 * s8 + s9 * s10)>
3002+
- !LinalgOperandDefConfig
3003+
name: K
3004+
kind: input_tensor
3005+
type_var: T2
3006+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3007+
(s1, s11, s2, s5, s9)>
3008+
- !LinalgOperandDefConfig
3009+
name: O
3010+
kind: output_tensor
3011+
type_var: U
3012+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3013+
(s0, s1, s11, s3, s7)>
29153014
- !LinalgOperandDefConfig
29163015
name: strides
29173016
kind: index_attr

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,7 @@ def conv_2d_ngchw_fgchw(
780780
T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
781781
),
782782
K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW),
783-
O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True),
783+
O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
784784
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
785785
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
786786
):
@@ -790,6 +790,32 @@ def conv_2d_ngchw_fgchw(
790790
* Input: NGCHW.
791791
* Kernel: FGCHW.
792792
793+
Numeric casting is performed on the operands to the inner multiply, promoting
794+
them to the same data type as the accumulator/output.
795+
"""
796+
implements(ConvolutionOpInterface)
797+
domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw)
798+
O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed(
799+
U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]
800+
) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw])
801+
802+
803+
@linalg_structured_op
804+
def conv_2d_ngchw_gfchw(
805+
I=TensorDef(
806+
T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW
807+
),
808+
K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW),
809+
O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True),
810+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
811+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
812+
):
813+
"""Performs 2-D grouped convolution.
814+
815+
Layout:
816+
* Input: NGCHW.
817+
* Kernel: GFCHW.
818+
793819
Numeric casting is performed on the operands to the inner multiply, promoting
794820
them to the same data type as the accumulator/output.
795821
"""

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,38 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
409409

410410
// -----
411411

412+
// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
413+
func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
414+
// CHECK: linalg.conv_2d_ngchw_fgchw
415+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
416+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
417+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
418+
// CHECK-SAME: outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
419+
%0 = linalg.conv_2d_ngchw_fgchw {dilations = dense<1> : tensor<2xi64>,
420+
strides = dense<1> : tensor<2xi64>}
421+
ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<2x5x3x3x3xf32>)
422+
outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
423+
return %0 : tensor<1x5x2x30x30xf32>
424+
}
425+
426+
// -----
427+
428+
// CHECK-LABEL: func @conv_2d_ngchw_gfchw
429+
func.func @conv_2d_ngchw_gfchw(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<5x2x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
430+
// CHECK: linalg.conv_2d_ngchw_gfchw
431+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
432+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
433+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
434+
// CHECK-SAME: outs(%{{.+}} : tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
435+
%0 = linalg.conv_2d_ngchw_gfchw {dilations = dense<1> : tensor<2xi64>,
436+
strides = dense<1> : tensor<2xi64>}
437+
ins (%input, %filter: tensor<1x5x3x32x32xf32>, tensor<5x2x3x3x3xf32>)
438+
outs (%init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32>
439+
return %0 : tensor<1x5x2x30x30xf32>
440+
}
441+
442+
// -----
443+
412444
// CHECK-LABEL: func @conv_3d_ndhwc_dhwcf
413445
func.func @conv_3d_ndhwc_dhwcf(%input: tensor<?x?x?x?x?xf32>, %filter: tensor<?x?x?x?x?xf32>, %init: tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32> {
414446
// CHECK: %{{.+}} = linalg.conv_3d_ndhwc_dhwcf

0 commit comments

Comments
 (0)