Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit deebf18

Browse files
cathyzhyi authored and Tobias Gysi committed
[mlir][linalg] Add pooling_nchw_max, conv_2d_nchw as yaml ops.
- Add pooling_nchw_max.
- Move conv_2d_nchw to yaml ops and add strides and dilation attributes.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D106658
1 parent ae69f46 commit deebf18

File tree

6 files changed

+229
-10
lines changed

6 files changed

+229
-10
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,88 @@ structured_op: !LinalgStructuredOpConfig
905905
- !ScalarExpression
906906
scalar_arg: K
907907
--- !LinalgOpConfig
908+
metadata: !LinalgOpMetadata
909+
name: conv_2d_nchw
910+
cpp_class_name: Conv2DNchwOp
911+
doc: |-
912+
Performs 2-D convolution.
913+
914+
Numeric casting is performed on the operands to the inner multiply, promoting
915+
them to the same data type as the accumulator/output.
916+
structured_op: !LinalgStructuredOpConfig
917+
args:
918+
- !LinalgOperandDefConfig
919+
name: I
920+
usage: InputOperand
921+
type_var: T1
922+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
923+
-> (s0, s1, s2, s3)>
924+
- !LinalgOperandDefConfig
925+
name: K
926+
usage: InputOperand
927+
type_var: T2
928+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
929+
-> (s4, s1, s5, s6)>
930+
- !LinalgOperandDefConfig
931+
name: O
932+
usage: OutputOperand
933+
type_var: U
934+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12]
935+
-> (s0, s4, s7, s8)>
936+
- !LinalgOperandDefConfig
937+
name: strides
938+
usage: IndexAttribute
939+
type_var: I64
940+
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
941+
s12] -> (s9, s10)>
942+
- !LinalgOperandDefConfig
943+
name: dilations
944+
usage: IndexAttribute
945+
type_var: I64
946+
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11,
947+
s12] -> (s11, s12)>
948+
indexing_maps: !LinalgIndexingMapsConfig
949+
static_indexing_maps:
950+
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
951+
s9, s10, s11, s12] -> (d0, d4, d2 * s9 + d5 * s11, d3 * s10 + d6 * s12)>
952+
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
953+
s9, s10, s11, s12] -> (d1, d4, d5, d6)>
954+
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
955+
s9, s10, s11, s12] -> (d0, d1, d2, d3)>
956+
iterator_types:
957+
- parallel
958+
- parallel
959+
- parallel
960+
- parallel
961+
- reduction
962+
- reduction
963+
- reduction
964+
assignments:
965+
- !ScalarAssign
966+
arg: O
967+
value: !ScalarExpression
968+
scalar_apply:
969+
fn_name: add
970+
operands:
971+
- !ScalarExpression
972+
scalar_arg: O
973+
- !ScalarExpression
974+
scalar_apply:
975+
fn_name: mul
976+
operands:
977+
- !ScalarExpression
978+
symbolic_cast:
979+
type_var: U
980+
operands:
981+
- !ScalarExpression
982+
scalar_arg: I
983+
- !ScalarExpression
984+
symbolic_cast:
985+
type_var: U
986+
operands:
987+
- !ScalarExpression
988+
scalar_arg: K
989+
--- !LinalgOpConfig
908990
metadata: !LinalgOpMetadata
909991
name: pooling_nhwc_sum
910992
cpp_class_name: PoolingNhwcSumOp
@@ -1047,6 +1129,77 @@ structured_op: !LinalgStructuredOpConfig
10471129
- !ScalarExpression
10481130
scalar_arg: I
10491131
--- !LinalgOpConfig
1132+
metadata: !LinalgOpMetadata
1133+
name: pooling_nchw_max
1134+
cpp_class_name: PoolingNchwMaxOp
1135+
doc: |-
1136+
Performs max pooling.
1137+
1138+
Numeric casting is performed on the input operand, promoting it to the same
1139+
data type as the accumulator/output.
1140+
structured_op: !LinalgStructuredOpConfig
1141+
args:
1142+
- !LinalgOperandDefConfig
1143+
name: I
1144+
usage: InputOperand
1145+
type_var: T1
1146+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
1147+
(s0, s1, s2, s3)>
1148+
- !LinalgOperandDefConfig
1149+
name: K
1150+
usage: InputOperand
1151+
type_var: T2
1152+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
1153+
(s4, s5)>
1154+
- !LinalgOperandDefConfig
1155+
name: O
1156+
usage: OutputOperand
1157+
type_var: U
1158+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
1159+
(s0, s1, s6, s7)>
1160+
- !LinalgOperandDefConfig
1161+
name: strides
1162+
usage: IndexAttribute
1163+
type_var: I64
1164+
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
1165+
-> (s8, s9)>
1166+
- !LinalgOperandDefConfig
1167+
name: dilations
1168+
usage: IndexAttribute
1169+
type_var: I64
1170+
attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
1171+
-> (s10, s11)>
1172+
indexing_maps: !LinalgIndexingMapsConfig
1173+
static_indexing_maps:
1174+
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
1175+
s10, s11] -> (d0, d1, d2 * s8 + d4 * s10, d3 * s9 + d5 * s11)>
1176+
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
1177+
s10, s11] -> (d4, d5)>
1178+
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
1179+
s10, s11] -> (d0, d1, d2, d3)>
1180+
iterator_types:
1181+
- parallel
1182+
- parallel
1183+
- parallel
1184+
- parallel
1185+
- reduction
1186+
- reduction
1187+
assignments:
1188+
- !ScalarAssign
1189+
arg: O
1190+
value: !ScalarExpression
1191+
scalar_apply:
1192+
fn_name: max
1193+
operands:
1194+
- !ScalarExpression
1195+
scalar_arg: O
1196+
- !ScalarExpression
1197+
symbolic_cast:
1198+
type_var: U
1199+
operands:
1200+
- !ScalarExpression
1201+
scalar_arg: I
1202+
--- !LinalgOpConfig
10501203
metadata: !LinalgOpMetadata
10511204
name: pooling_nhwc_min
10521205
cpp_class_name: PoolingNhwcMinOp

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,6 @@ def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F
125125
O(n, h, w, f), MulFOp(I(n, h + kh, w + kw, c), K(f, kh, kw, c)));
126126
}
127127

128-
ods_def<ConvNCHWOp>:
129-
def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) {
130-
O(n, f, h, w) = AddFOp<kh, kw>(
131-
O(n, f, h, w), MulFOp(I(n, c, h + kh, w + kw), K(f, c, kh, kw)));
132-
}
133-
134128
ods_def<ConvDHWOp>:
135129
def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) {
136130
O(d, h, w) = AddFOp<kd, kh, kw>(

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1186,8 +1186,8 @@ void mlir::linalg::populateConvVectorizationPatterns(
11861186
populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
11871187
tiling, promotion, vectorization, tileSizes);
11881188

1189-
populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
1190-
tileSizes);
1189+
populateVectorizationPatterns<Conv2DNchwOp, 4>(tiling, promotion,
1190+
vectorization, tileSizes);
11911191
populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
11921192
tiling, promotion, vectorization, tileSizes);
11931193

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,23 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
205205
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
206206
D.c]) * cast(U, K[D.kh, D.kw, D.c])
207207

208+
@linalg_structured_op
209+
def conv_2d_nchw(
210+
I=TensorDef(T1, S.N, S.C, S.IH, S.IW),
211+
K=TensorDef(T2, S.F, S.C, S.KH, S.KW),
212+
O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True),
213+
strides=AttributeDef(S.SH, S.SW),
214+
dilations=AttributeDef(S.DH, S.DW)):
215+
"""Performs 2-D convolution.
216+
217+
Numeric casting is performed on the operands to the inner multiply, promoting
218+
them to the same data type as the accumulator/output.
219+
"""
220+
domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw)
221+
O[D.n, D.f, D.oh, D.ow] += cast(
222+
U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
223+
]) * cast(U, K[D.f, D.c, D.kh, D.kw])
224+
208225

209226
@linalg_structured_op
210227
def pooling_nhwc_sum(
@@ -240,6 +257,22 @@ def pooling_nhwc_max(
240257
cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
241258
D.c]))
242259

260+
@linalg_structured_op
261+
def pooling_nchw_max(
262+
I=TensorDef(T1, S.N, S.C, S.H, S.W),
263+
K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
264+
O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True),
265+
strides=AttributeDef(S.SH, S.SW),
266+
dilations=AttributeDef(S.DH, S.DW)):
267+
"""Performs max pooling.
268+
269+
Numeric casting is performed on the input operand, promoting it to the same
270+
data type as the accumulator/output.
271+
"""
272+
domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
273+
O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
274+
cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
275+
]))
243276

244277
@linalg_structured_op
245278
def pooling_nhwc_min(

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ func @depthwise_conv_2d_input_nhwc_filter_hwcf_tensor(%input: tensor<2x4x5x2xf32
3030
return %0 : tensor<2x3x4x2x3xf32>
3131
}
3232

33+
// CHECK-LABEL: func @conv_2d_nchw_tensor
34+
func @conv_2d_nchw_tensor(%input: tensor<2x2x4x5xf32>, %filter: tensor<4x2x3x3xf32>) -> tensor<2x4x2x3xf32> {
35+
%cst = constant 0.000000e+00 : f32
36+
%init = linalg.init_tensor [2, 4, 2, 3] : tensor<2x4x2x3xf32>
37+
%fill = linalg.fill(%cst, %init) : f32, tensor<2x4x2x3xf32> -> tensor<2x4x2x3xf32>
38+
// CHECK: %{{.+}} = linalg.conv_2d_nchw
39+
// CHECK-SAME: {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
40+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
41+
// CHECK-SAME: outs(%{{.+}} : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
42+
// CHECK: return %{{.+}} : tensor<2x4x2x3xf32>
43+
// CHECK: }
44+
%0 = linalg.conv_2d_nchw
45+
{dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
46+
ins(%input, %filter: tensor<2x2x4x5xf32>, tensor<4x2x3x3xf32>)
47+
outs(%fill : tensor<2x4x2x3xf32>) -> tensor<2x4x2x3xf32>
48+
return %0 : tensor<2x4x2x3xf32>
49+
}
50+
3351
// CHECK-LABEL: func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref
3452
func @depthwise_conv_2d_input_nhwc_filter_hwcf_memref(%input: memref<2x4x5x2xf32>, %filter: memref<2x2x2x3xf32>, %output: memref<2x3x4x2x3xf32>) {
3553
// CHECK: linalg.depthwise_conv_2d_input_nhwc_filter_hwcf
@@ -381,6 +399,25 @@ func @pooling_nhwc_max_tensor(%input: tensor<1x4x4x1xf32>) -> tensor<1x2x2x1xf32
381399
return %res : tensor<1x2x2x1xf32>
382400
}
383401

402+
// -----
403+
// CHECK-LABEL: func @pooling_nchw_max_tensor
404+
// CHECK: %{{.+}} = linalg.pooling_nchw_max
405+
// CHECK-SAME: dilations = dense<1> : tensor<2xi64>
406+
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
407+
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<1x1x4x4xf32>, tensor<3x3xf32>)
408+
// CHECK-SAME: outs(%{{.+}} : tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
409+
410+
func @pooling_nchw_max_tensor(%input: tensor<1x1x4x4xf32>) -> tensor<1x1x2x2xf32> {
411+
%fake = linalg.init_tensor [3, 3] : tensor<3x3xf32>
412+
%init = linalg.init_tensor [1, 1, 2, 2] : tensor<1x1x2x2xf32>
413+
%cst = constant 0.000000e+00 : f32
414+
%fill = linalg.fill(%cst, %init) : f32, tensor<1x1x2x2xf32> -> tensor<1x1x2x2xf32>
415+
%res = linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
416+
ins(%input, %fake: tensor<1x1x4x4xf32>, tensor<3x3xf32>)
417+
outs(%fill: tensor<1x1x2x2xf32>) -> tensor<1x1x2x2xf32>
418+
return %res : tensor<1x1x2x2xf32>
419+
}
420+
384421
// -----
385422

386423
// CHECK-LABEL: func @pooling_nhwc_max

mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@ func @alloc_4d_filled_f32(%s1 : index, %s2 : index, %s3 : index, %s4 : index, %f
3030
}
3131

3232
func @conv_2d_nchw(%arg0: memref<?x?x?x?xf32>, %arg1: memref<?x?x?x?xf32>, %arg2: memref<?x?x?x?xf32>) {
33-
linalg.conv_2d_nchw ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
34-
outs (%arg2: memref<?x?x?x?xf32>)
33+
linalg.conv_2d_nchw
34+
{dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
35+
ins (%arg0, %arg1: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
36+
outs (%arg2: memref<?x?x?x?xf32>)
3537
return
3638
}
3739

0 commit comments

Comments
 (0)