Skip to content

Commit 71fa1a1

Browse files
stefankoncarevic authored and Groverkss committed
[mlir][linalg] Add Grouped Convolution Ops: conv_2d_nhwgc_gfhwc and conv_2d_nhwgc_gfhwc_q (llvm#108192)
This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d Ops. - Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp. - Updated the conversion process to use these new ops for tosa group conv2d operations.
1 parent a5f689b commit 71fa1a1

File tree

3 files changed

+330
-0
lines changed

3 files changed

+330
-0
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3475,6 +3475,243 @@ structured_op: !LinalgStructuredOpConfig
34753475
- !ScalarExpression
34763476
scalar_arg: K
34773477
--- !LinalgOpConfig
3478+
metadata: !LinalgOpMetadata
3479+
name: conv_2d_nhwgc_gfhwc
3480+
cpp_class_name: Conv2DNhwgcGfhwcOp
3481+
doc: |-
3482+
Performs 2-D grouped convolution.
3483+
3484+
Layout:
3485+
* Input: NHWGC.
3486+
* Kernel: GFHWC.
3487+
3488+
Numeric casting is performed on the operands to the inner multiply, promoting
3489+
them to the same data type as the accumulator/output.
3490+
implements:
3491+
- LinalgConvolutionOpInterface
3492+
structured_op: !LinalgStructuredOpConfig
3493+
args:
3494+
- !LinalgOperandDefConfig
3495+
name: I
3496+
kind: input_tensor
3497+
type_var: T1
3498+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3499+
(s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
3500+
- !LinalgOperandDefConfig
3501+
name: K
3502+
kind: input_tensor
3503+
type_var: T2
3504+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3505+
(s9, s11, s3, s7, s10)>
3506+
- !LinalgOperandDefConfig
3507+
name: O
3508+
kind: output_tensor
3509+
type_var: U
3510+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3511+
(s0, s1, s5, s9, s11)>
3512+
- !LinalgOperandDefConfig
3513+
name: strides
3514+
kind: index_attr
3515+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3516+
-> (s2, s6)>
3517+
default_indices:
3518+
- 1
3519+
- 1
3520+
- !LinalgOperandDefConfig
3521+
name: dilations
3522+
kind: index_attr
3523+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3524+
-> (s4, s8)>
3525+
default_indices:
3526+
- 1
3527+
- 1
3528+
indexing_maps: !LinalgIndexingMapsConfig
3529+
static_indexing_maps:
3530+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3531+
s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
3532+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3533+
s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
3534+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3535+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
3536+
iterator_types:
3537+
- parallel
3538+
- parallel
3539+
- parallel
3540+
- parallel
3541+
- parallel
3542+
- reduction
3543+
- reduction
3544+
- reduction
3545+
assignments:
3546+
- !ScalarAssign
3547+
arg: O
3548+
value: !ScalarExpression
3549+
scalar_fn:
3550+
kind: binary
3551+
fn_name: add
3552+
operands:
3553+
- !ScalarExpression
3554+
scalar_arg: O
3555+
- !ScalarExpression
3556+
scalar_fn:
3557+
kind: binary
3558+
fn_name: mul
3559+
operands:
3560+
- !ScalarExpression
3561+
scalar_fn:
3562+
kind: type
3563+
fn_name: cast_signed
3564+
type_var: U
3565+
operands:
3566+
- !ScalarExpression
3567+
scalar_arg: I
3568+
- !ScalarExpression
3569+
scalar_fn:
3570+
kind: type
3571+
fn_name: cast_signed
3572+
type_var: U
3573+
operands:
3574+
- !ScalarExpression
3575+
scalar_arg: K
3576+
--- !LinalgOpConfig
3577+
metadata: !LinalgOpMetadata
3578+
name: conv_2d_nhwgc_gfhwc_q
3579+
cpp_class_name: Conv2DNhwgcGfhwcQOp
3580+
doc: |-
3581+
Performs 2-D grouped convolution with zero point offsets.
3582+
3583+
Layout:
3584+
* Input: NHWGC.
3585+
* Kernel: GFHWC.
3586+
3587+
Numeric casting is performed on the operands to the inner multiply, promoting
3588+
them to the same data type as the accumulator/output. This includes the zero
3589+
point offsets common to quantized operations.
3590+
implements:
3591+
- LinalgConvolutionOpInterface
3592+
structured_op: !LinalgStructuredOpConfig
3593+
args:
3594+
- !LinalgOperandDefConfig
3595+
name: I
3596+
kind: input_tensor
3597+
type_var: T1
3598+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3599+
(s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
3600+
- !LinalgOperandDefConfig
3601+
name: K
3602+
kind: input_tensor
3603+
type_var: T2
3604+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3605+
(s9, s11, s3, s7, s10)>
3606+
- !LinalgOperandDefConfig
3607+
name: IZp
3608+
kind: scalar
3609+
type_var: I32
3610+
- !LinalgOperandDefConfig
3611+
name: KZp
3612+
kind: scalar
3613+
type_var: I32
3614+
- !LinalgOperandDefConfig
3615+
name: O
3616+
kind: output_tensor
3617+
type_var: U
3618+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
3619+
(s0, s1, s5, s9, s11)>
3620+
- !LinalgOperandDefConfig
3621+
name: strides
3622+
kind: index_attr
3623+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3624+
-> (s2, s6)>
3625+
default_indices:
3626+
- 1
3627+
- 1
3628+
- !LinalgOperandDefConfig
3629+
name: dilations
3630+
kind: index_attr
3631+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
3632+
-> (s4, s8)>
3633+
default_indices:
3634+
- 1
3635+
- 1
3636+
indexing_maps: !LinalgIndexingMapsConfig
3637+
static_indexing_maps:
3638+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3639+
s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
3640+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3641+
s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
3642+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3643+
s8, s9, s10, s11] -> ()>
3644+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3645+
s8, s9, s10, s11] -> ()>
3646+
- affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
3647+
s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
3648+
iterator_types:
3649+
- parallel
3650+
- parallel
3651+
- parallel
3652+
- parallel
3653+
- parallel
3654+
- reduction
3655+
- reduction
3656+
- reduction
3657+
assignments:
3658+
- !ScalarAssign
3659+
arg: O
3660+
value: !ScalarExpression
3661+
scalar_fn:
3662+
kind: binary
3663+
fn_name: add
3664+
operands:
3665+
- !ScalarExpression
3666+
scalar_arg: O
3667+
- !ScalarExpression
3668+
scalar_fn:
3669+
kind: binary
3670+
fn_name: mul
3671+
operands:
3672+
- !ScalarExpression
3673+
scalar_fn:
3674+
kind: binary
3675+
fn_name: sub
3676+
operands:
3677+
- !ScalarExpression
3678+
scalar_fn:
3679+
kind: type
3680+
fn_name: cast_signed
3681+
type_var: U
3682+
operands:
3683+
- !ScalarExpression
3684+
scalar_arg: I
3685+
- !ScalarExpression
3686+
scalar_fn:
3687+
kind: type
3688+
fn_name: cast_signed
3689+
type_var: U
3690+
operands:
3691+
- !ScalarExpression
3692+
scalar_arg: IZp
3693+
- !ScalarExpression
3694+
scalar_fn:
3695+
kind: binary
3696+
fn_name: sub
3697+
operands:
3698+
- !ScalarExpression
3699+
scalar_fn:
3700+
kind: type
3701+
fn_name: cast_signed
3702+
type_var: U
3703+
operands:
3704+
- !ScalarExpression
3705+
scalar_arg: K
3706+
- !ScalarExpression
3707+
scalar_fn:
3708+
kind: type
3709+
fn_name: cast_signed
3710+
type_var: U
3711+
operands:
3712+
- !ScalarExpression
3713+
scalar_arg: KZp
3714+
--- !LinalgOpConfig
34783715
metadata: !LinalgOpMetadata
34793716
name: conv_2d_ngchw_gfchw_q
34803717
cpp_class_name: Conv2DNgchwGfchwQOp

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,67 @@ def conv_2d_ngchw_gfchw(
964964
) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw])
965965

966966

967+
@linalg_structured_op
968+
def conv_2d_nhwgc_gfhwc(
969+
I=TensorDef(
970+
T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
971+
),
972+
K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
973+
O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
974+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
975+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
976+
):
977+
"""Performs 2-D grouped convolution.
978+
979+
Layout:
980+
* Input: NHWGC.
981+
* Kernel: GFHWC.
982+
983+
Numeric casting is performed on the operands to the inner multiply, promoting
984+
them to the same data type as the accumulator/output.
985+
"""
986+
implements(ConvolutionOpInterface)
987+
domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
988+
O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed(
989+
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
990+
) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
991+
992+
993+
@linalg_structured_op
994+
def conv_2d_nhwgc_gfhwc_q(
995+
I=TensorDef(
996+
T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C
997+
),
998+
K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C),
999+
IZp=ScalarDef(I32),
1000+
KZp=ScalarDef(I32),
1001+
O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True),
1002+
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
1003+
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
1004+
):
1005+
"""Performs 2-D grouped convolution with zero point offsets.
1006+
1007+
Layout:
1008+
* Input: NHWGC.
1009+
* Kernel: GFHWC.
1010+
1011+
Numeric casting is performed on the operands to the inner multiply, promoting
1012+
them to the same data type as the accumulator/output. This includes the zero
1013+
point offsets common to quantized operations.
1014+
"""
1015+
implements(ConvolutionOpInterface)
1016+
domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c)
1017+
O[D.n, D.oh, D.ow, D.g, D.fg] += (
1018+
TypeFn.cast_signed(
1019+
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c]
1020+
)
1021+
- TypeFn.cast_signed(U, IZp)
1022+
) * (
1023+
TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c])
1024+
- TypeFn.cast_signed(U, KZp)
1025+
)
1026+
1027+
9671028
@linalg_structured_op
9681029
def conv_2d_ngchw_gfchw_q(
9691030
I=TensorDef(

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,38 @@ func.func @conv_2d_ngchw_fgchw(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x
409409

410410
// -----
411411

412+
// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc
413+
func.func @conv_2d_nhwgc_gfhwc(%input: memref<?x?x?x?x?xf32>, %filter: memref<?x?x?x?x?xf32>, %output: memref<?x?x?x?x?xf32>) {
414+
// CHECK: linalg.conv_2d_nhwgc_gfhwc
415+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
416+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
417+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
418+
// CHECK-SAME: outs(%{{.+}} : memref<?x?x?x?x?xf32>)
419+
linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
420+
strides = dense<1> : tensor<2xi64>}
421+
ins (%input, %filter: memref<?x?x?x?x?xf32>, memref<?x?x?x?x?xf32>)
422+
outs (%output: memref<?x?x?x?x?xf32>)
423+
return
424+
}
425+
426+
// -----
427+
428+
// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc_tensor
429+
func.func @conv_2d_nhwgc_gfhwc_tensor(%input: tensor<1x28x28x2x3xf32>, %filter: tensor<2x8x3x3x3xf32>, %output: tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32> {
430+
// CHECK: linalg.conv_2d_nhwgc_gfhwc
431+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
432+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
433+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x28x28x2x3xf32>, tensor<2x8x3x3x3xf32>)
434+
// CHECK-SAME: outs(%{{.+}} : tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32>
435+
%0 = linalg.conv_2d_nhwgc_gfhwc {dilations = dense<1> : tensor<2xi64>,
436+
strides = dense<1> : tensor<2xi64>}
437+
ins (%input, %filter: tensor<1x28x28x2x3xf32>, tensor<2x8x3x3x3xf32>)
438+
outs (%output: tensor<1x26x26x2x8xf32>) -> tensor<1x26x26x2x8xf32>
439+
return %0 : tensor<1x26x26x2x8xf32>
440+
}
441+
442+
// -----
443+
412444
// CHECK-LABEL: func @conv_2d_ngchw_fgchw_dimensions
413445
func.func @conv_2d_ngchw_fgchw_dimensions(%input: tensor<1x5x3x32x32xf32>, %filter: tensor<2x5x3x3x3xf32>, %init: tensor<1x5x2x30x30xf32>) -> tensor<1x5x2x30x30xf32> {
414446
// CHECK: linalg.conv_2d_ngchw_fgchw

0 commit comments

Comments (0)