Skip to content

Commit 3005ca2

Browse files
authored
[mlir][linalg] Add a test for inferConvolutionDimsImpl (#90057)
Adds a test for `inferConvolutionDimsImpl` to exercise the logic for depthwise convs.
1 parent 76739d1 commit 3005ca2

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

mlir/test/Dialect/Linalg/match-ops-interpreter.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,28 @@ module attributes { transform.target_tag = "start_here" } {
10621062
return %result : tensor<10x18x15xf64>
10631063
}
10641064

1065+
func.func @convolution_depthwise(%input: tensor<1x10x196x48xf32>, %filter: tensor<1x4x48xf32>) -> tensor<1x10x191x48xf32> {
1066+
%cst = arith.constant 0.0 : f32
1067+
%empty = tensor.empty() : tensor<1x10x191x48xf32>
1068+
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
1069+
// expected-remark @below {{convolution}}
1070+
// expected-remark @below {{batch dims 0}}
1071+
// expected-remark @below {{output image dims 1 : i64, 2 : i64}}
1072+
// expected-remark @below {{output channel dims}}
1073+
// expected-remark @below {{filter loop dims 4 : i64, 5 : i64}}
1074+
// expected-remark @below {{input channel dims}}
1075+
// expected-remark @below {{depth dims 3}}
1076+
// expected-remark @below {{strides 1 : i64, 1 : i64}}
1077+
// expected-remark @below {{dilations 1 : i64, 1 : i64}}
1078+
%result = linalg.depthwise_conv_2d_nhwc_hwc {
1079+
dilations = dense<1> : tensor<2xi64>,
1080+
strides = dense<1> : tensor<2xi64>}
1081+
ins(%input, %filter : tensor<1x10x196x48xf32>, tensor<1x4x48xf32>)
1082+
outs(%fill : tensor<1x10x191x48xf32>) -> tensor<1x10x191x48xf32>
1083+
1084+
return %result : tensor<1x10x191x48xf32>
1085+
}
1086+
10651087
func.func @convolution_multi_channel(%input: tensor<2x34x68x16xf32>, %filter: tensor<8x2x3x5x16x16xf32>) -> tensor<8x32x32x16xf32> {
10661088
%cst = arith.constant 0.0 : f32
10671089
%empty = tensor.empty() : tensor<8x32x32x16xf32>

0 commit comments

Comments
 (0)