Skip to content

Commit 6ca0b27

Browse files
author
Peiming Liu
committed
[mlir][sparse] more complicated test for dual sparse convolution kernel.
Reviewed By: anlunx Differential Revision: https://reviews.llvm.org/D158443
1 parent bfe390c commit 6ca0b27

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_conv_2d_nchw_fchw.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ func.func @conv_2d_nchw_fchw_CCCC(%arg0: tensor<?x?x?x?xf32, #CCCC>, %arg1: tens
7474
return %ret : tensor<?x?x?x?xf32>
7575
}
7676

77+
func.func @conv_2d_nchw_fchw_CCCC_CCCC(%arg0: tensor<?x?x?x?xf32, #CCCC>, %arg1: tensor<?x?x?x?xf32, #CCCC>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
78+
%ret = linalg.conv_2d_nchw_fchw {dilations = dense<1> : tensor<2xi64>,
79+
strides = dense<1> : tensor<2xi64>}
80+
ins (%arg0, %arg1: tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32, #CCCC>)
81+
outs (%arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
82+
return %ret : tensor<?x?x?x?xf32>
83+
}
84+
7785
func.func @entry() {
7886
%c0 = arith.constant 0 : index
7987
%c1 = arith.constant 1 : index
@@ -96,9 +104,13 @@ func.func @entry() {
96104
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
97105
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
98106

107+
%filter2D_nhwc_CCCC = sparse_tensor.convert %filter2D_nhwc
108+
: tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32, #CCCC>
109+
99110
%dense_ret = call @conv_2d_nchw_fchw(%in2D_nhwc, %filter2D_nhwc, %out2D_nhwc) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
100111
%CCCC_ret = call @conv_2d_nchw_fchw_CDCD(%in2D_nhwc_CCCD, %filter2D_nhwc, %out2D_nhwc_CCCD) : (tensor<?x?x?x?xf32, #CDCD>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
101112
%CDCD_ret = call @conv_2d_nchw_fchw_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc, %out2D_nhwc_CCCC) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
113+
%dual_CCCC_ret = call @conv_2d_nchw_fchw_CCCC_CCCC(%in2D_nhwc_CCCC, %filter2D_nhwc_CCCC, %out2D_nhwc) : (tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32, #CCCC>, tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>)
102114

103115

104116
// CHECK: ( ( ( ( 108, 124, 124, 124, 108, 108 ),
@@ -167,6 +179,28 @@ func.func @entry() {
167179
: tensor<?x?x?x?xf32>, vector<3x1x6x6xf32>
168180
vector.print %v2 : vector<3x1x6x6xf32>
169181

182+
// CHECK: ( ( ( ( 108, 124, 124, 124, 108, 108 ),
183+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
184+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
185+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
186+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
187+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ),
188+
// CHECK-SAME: ( ( ( 108, 108, 108, 108, 108, 108 ),
189+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
190+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
191+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
192+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
193+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ),
194+
// CHECK-SAME: ( ( ( 108, 108, 108, 108, 108, 108 ),
195+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
196+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
197+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
198+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
199+
// CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ) )
200+
%v3 = vector.transfer_read %dual_CCCC_ret[%c0, %c0, %c0, %c0], %zero
201+
: tensor<?x?x?x?xf32>, vector<3x1x6x6xf32>
202+
vector.print %v3 : vector<3x1x6x6xf32>
203+
170204
// Free the resources
171205
bufferization.dealloc_tensor %in2D_nhwc : tensor<?x?x?x?xf32>
172206
bufferization.dealloc_tensor %filter2D_nhwc : tensor<?x?x?x?xf32>
@@ -176,5 +210,6 @@ func.func @entry() {
176210

177211
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
178212
bufferization.dealloc_tensor %in2D_nhwc_CCCD : tensor<?x?x?x?xf32, #CDCD>
213+
bufferization.dealloc_tensor %filter2D_nhwc_CCCC : tensor<?x?x?x?xf32, #CCCC>
179214
return
180215
}

0 commit comments

Comments
 (0)