@@ -74,6 +74,14 @@ func.func @conv_2d_nchw_fchw_CCCC(%arg0: tensor<?x?x?x?xf32, #CCCC>, %arg1: tens
74
74
return %ret : tensor <?x?x?x?xf32 >
75
75
}
76
76
77
+ func.func @conv_2d_nchw_fchw_CCCC_CCCC (%arg0: tensor <?x?x?x?xf32 , #CCCC >, %arg1: tensor <?x?x?x?xf32 , #CCCC >, %arg2: tensor <?x?x?x?xf32 >) -> tensor <?x?x?x?xf32 > {
78
+ %ret = linalg.conv_2d_nchw_fchw {dilations = dense <1 > : tensor <2 xi64 >,
79
+ strides = dense <1 > : tensor <2 xi64 >}
80
+ ins (%arg0 , %arg1: tensor <?x?x?x?xf32 , #CCCC >, tensor <?x?x?x?xf32 , #CCCC >)
81
+ outs (%arg2: tensor <?x?x?x?xf32 >) -> tensor <?x?x?x?xf32 >
82
+ return %ret : tensor <?x?x?x?xf32 >
83
+ }
84
+
77
85
func.func @entry () {
78
86
%c0 = arith.constant 0 : index
79
87
%c1 = arith.constant 1 : index
@@ -96,9 +104,13 @@ func.func @entry() {
96
104
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
97
105
: tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CCCC >
98
106
107
+ %filter2D_nhwc_CCCC = sparse_tensor.convert %filter2D_nhwc
108
+ : tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CCCC >
109
+
99
110
%dense_ret = call @conv_2d_nchw_fchw (%in2D_nhwc , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
100
111
%CCCC_ret = call @conv_2d_nchw_fchw_CDCD (%in2D_nhwc_CCCD , %filter2D_nhwc , %out2D_nhwc_CCCD ) : (tensor <?x?x?x?xf32 , #CDCD >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
101
112
%CDCD_ret = call @conv_2d_nchw_fchw_CCCC (%in2D_nhwc_CCCC , %filter2D_nhwc , %out2D_nhwc_CCCC ) : (tensor <?x?x?x?xf32 , #CCCC >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
113
+ %dual_CCCC_ret = call @conv_2d_nchw_fchw_CCCC_CCCC (%in2D_nhwc_CCCC , %filter2D_nhwc_CCCC , %out2D_nhwc ) : (tensor <?x?x?x?xf32 , #CCCC >, tensor <?x?x?x?xf32 , #CCCC >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
102
114
103
115
104
116
// CHECK: ( ( ( ( 108, 124, 124, 124, 108, 108 ),
@@ -167,6 +179,28 @@ func.func @entry() {
167
179
: tensor <?x?x?x?xf32 >, vector <3 x1 x6 x6 xf32 >
168
180
vector.print %v2 : vector <3 x1 x6 x6 xf32 >
169
181
182
+ // CHECK: ( ( ( ( 108, 124, 124, 124, 108, 108 ),
183
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
184
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
185
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
186
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
187
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ),
188
+ // CHECK-SAME: ( ( ( 108, 108, 108, 108, 108, 108 ),
189
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
190
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
191
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
192
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
193
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ),
194
+ // CHECK-SAME: ( ( ( 108, 108, 108, 108, 108, 108 ),
195
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
196
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
197
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
198
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ),
199
+ // CHECK-SAME: ( 108, 108, 108, 108, 108, 108 ) ) ) )
200
+ %v3 = vector.transfer_read %dual_CCCC_ret [%c0 , %c0 , %c0 , %c0 ], %zero
201
+ : tensor <?x?x?x?xf32 >, vector <3 x1 x6 x6 xf32 >
202
+ vector.print %v3 : vector <3 x1 x6 x6 xf32 >
203
+
170
204
// Free the resources
171
205
bufferization.dealloc_tensor %in2D_nhwc : tensor <?x?x?x?xf32 >
172
206
bufferization.dealloc_tensor %filter2D_nhwc : tensor <?x?x?x?xf32 >
@@ -176,5 +210,6 @@ func.func @entry() {
176
210
177
211
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor <?x?x?x?xf32 , #CCCC >
178
212
bufferization.dealloc_tensor %in2D_nhwc_CCCD : tensor <?x?x?x?xf32 , #CDCD >
213
+ bufferization.dealloc_tensor %filter2D_nhwc_CCCC : tensor <?x?x?x?xf32 , #CCCC >
179
214
return
180
215
}
0 commit comments