@@ -69,6 +69,14 @@ func.func @conv_2d_nhwc_hwcf_CDCC(%arg0: tensor<?x?x?x?xf32, #CDCC>, %arg1: tens
69
69
return %ret : tensor <?x?x?x?xf32 >
70
70
}
71
71
72
+ func.func @conv_2d_nhwc_hwcf_dual_CDCC (%arg0: tensor <?x?x?x?xf32 , #CDCC >, %arg1: tensor <?x?x?x?xf32 , #CDCC >, %arg2: tensor <?x?x?x?xf32 >) -> tensor <?x?x?x?xf32 > {
73
+ %ret = linalg.conv_2d_nhwc_hwcf {dilations = dense <1 > : tensor <2 xi64 >,
74
+ strides = dense <2 > : tensor <2 xi64 >}
75
+ ins (%arg0 , %arg1: tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 , #CDCC >)
76
+ outs (%arg2: tensor <?x?x?x?xf32 >) -> tensor <?x?x?x?xf32 >
77
+ return %ret : tensor <?x?x?x?xf32 >
78
+ }
79
+
72
80
73
81
func.func @entry () {
74
82
%c0 = arith.constant 0 : index
@@ -87,16 +95,28 @@ func.func @entry() {
87
95
88
96
%in2D_nhwc_CCCC = sparse_tensor.convert %in2D_nhwc
89
97
: tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CCCC >
98
+ %filter2D_nhwc_CDCC = sparse_tensor.convert %filter2D_nhwc
99
+ : tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CDCC >
90
100
%in2D_nhwc_CDCC = sparse_tensor.convert %in2D_nhwc
91
101
: tensor <?x?x?x?xf32 > to tensor <?x?x?x?xf32 , #CDCC >
92
102
93
103
%dense_ret = call @conv_2d_nhwc_hwcf (%in2D_nhwc , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
94
104
%CCCC_ret = call @conv_2d_nhwc_hwcf_CCCC (%in2D_nhwc_CCCC , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 , #CCCC >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
95
105
%CDCC_ret = call @conv_2d_nhwc_hwcf_CDCC (%in2D_nhwc_CDCC , %filter2D_nhwc , %out2D_nhwc ) : (tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
96
106
107
+ %dual_CDCC_ret = call @conv_2d_nhwc_hwcf_dual_CDCC (%in2D_nhwc_CDCC , %filter2D_nhwc_CDCC , %out2D_nhwc )
108
+ : (tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 , #CDCC >, tensor <?x?x?x?xf32 >) -> (tensor <?x?x?x?xf32 >)
109
+
97
110
// CHECK: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
98
111
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
99
112
// CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
113
+ %v_dual = vector.transfer_read %dual_CDCC_ret [%c0 , %c0 , %c0 , %c0 ], %zero
114
+ : tensor <?x?x?x?xf32 >, vector <3 x3 x3 x1 xf32 >
115
+ vector.print %v_dual : vector <3 x3 x3 x1 xf32 >
116
+
117
+ // CHECK-NEXT: ( ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 20 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
118
+ // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ),
119
+ // CHECK-SAME: ( ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ), ( ( 0 ), ( 0 ), ( 0 ) ) ) )
100
120
%dense_v = vector.transfer_read %dense_ret [%c0 , %c0 , %c0 , %c0 ], %zero
101
121
: tensor <?x?x?x?xf32 >, vector <3 x3 x3 x1 xf32 >
102
122
vector.print %dense_v : vector <3 x3 x3 x1 xf32 >
@@ -120,6 +140,7 @@ func.func @entry() {
120
140
bufferization.dealloc_tensor %filter2D_nhwc : tensor <?x?x?x?xf32 >
121
141
bufferization.dealloc_tensor %out2D_nhwc : tensor <?x?x?x?xf32 >
122
142
143
+ bufferization.dealloc_tensor %filter2D_nhwc_CDCC : tensor <?x?x?x?xf32 , #CDCC >
123
144
bufferization.dealloc_tensor %in2D_nhwc_CCCC : tensor <?x?x?x?xf32 , #CCCC >
124
145
bufferization.dealloc_tensor %in2D_nhwc_CDCC : tensor <?x?x?x?xf32 , #CDCC >
125
146
return
0 commit comments